From 7b7474e717dec0fc953ffcbcda13ee6e052bf82d Mon Sep 17 00:00:00 2001 From: Phillip Cloud Date: Fri, 20 Sep 2013 22:26:46 -0400 Subject: [PATCH] CLN: more plotting test cleanups --- pandas/tests/test_graphics.py | 94 +++++++++++---------------- pandas/tests/test_rplot.py | 41 ++++++++---- pandas/tseries/tests/test_plotting.py | 14 +--- pandas/util/testing.py | 27 ++++++++ 4 files changed, 95 insertions(+), 81 deletions(-) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index cb6ec3d648afa..558bf17b0cd5c 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -29,17 +29,11 @@ def _skip_if_no_scipy(): raise nose.SkipTest +@tm.mplskip class TestSeriesPlots(unittest.TestCase): - @classmethod - def setUpClass(cls): - try: - import matplotlib as mpl - mpl.use('Agg', warn=False) - cls.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1') - except ImportError: - raise nose.SkipTest("matplotlib not installed") - def setUp(self): + import matplotlib as mpl + self.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1') self.ts = tm.makeTimeSeries() self.ts.name = 'ts' @@ -50,8 +44,7 @@ def setUp(self): self.iseries.name = 'iseries' def tearDown(self): - import matplotlib.pyplot as plt - plt.close('all') + tm.close() @slow def test_plot(self): @@ -352,24 +345,14 @@ def test_dup_datetime_index_plot(self): _check_plot_works(s.plot) +@tm.mplskip class TestDataFramePlots(unittest.TestCase): - - @classmethod - def setUpClass(cls): - # import sys - # if 'IPython' in sys.modules: - # raise nose.SkipTest - - try: - import matplotlib as mpl - mpl.use('Agg', warn=False) - cls.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1') - except ImportError: - raise nose.SkipTest("matplotlib not installed") + def setUp(self): + import matplotlib as mpl + self.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1') def tearDown(self): - import matplotlib.pyplot as plt - plt.close('all') + tm.close() @slow def test_plot(self): @@ -949,19 +932,10 @@ def test_invalid_kind(self): df.plot(kind='aasdf') +@tm.mplskip class TestDataFrameGroupByPlots(unittest.TestCase): - @classmethod - def setUpClass(cls): - try: - import matplotlib as mpl - mpl.use('Agg', warn=False) - except ImportError: - raise nose.SkipTest - def tearDown(self): - import matplotlib.pyplot as plt - for fignum in plt.get_fignums(): - plt.close(fignum) + tm.close() @slow def test_boxplot(self): @@ -999,13 +973,16 @@ def test_time_series_plot_color_with_empty_kwargs(self): import matplotlib as mpl def_colors = mpl.rcParams['axes.color_cycle'] + index = date_range('1/1/2000', periods=12) + s = Series(np.arange(1, 13), index=index) + + ncolors = 3 - for i in range(3): - ax = Series(np.arange(12) + 1, index=date_range('1/1/2000', - periods=12)).plot() + for i in range(ncolors): + ax = s.plot() line_colors = [l.get_color() for l in ax.get_lines()] - self.assertEqual(line_colors, def_colors[:3]) + self.assertEqual(line_colors, def_colors[:ncolors]) @slow def test_grouped_hist(self): @@ -1155,27 +1132,30 @@ def _check_plot_works(f, *args, **kwargs): import matplotlib.pyplot as plt try: - fig = kwargs['figure'] - except KeyError: - fig = plt.gcf() - plt.clf() - ax = kwargs.get('ax', fig.add_subplot(211)) - ret = f(*args, **kwargs) + try: + fig = kwargs['figure'] + except KeyError: + fig = plt.gcf() - assert ret is not None - assert_is_valid_plot_return_object(ret) + plt.clf() - try: - kwargs['ax'] = fig.add_subplot(212) + ax = kwargs.get('ax', fig.add_subplot(211)) ret = f(*args, **kwargs) - except Exception: - pass - else: + assert_is_valid_plot_return_object(ret) - with ensure_clean() as path: - plt.savefig(path) - plt.close(fig) + try: + kwargs['ax'] = fig.add_subplot(212) + ret = f(*args, **kwargs) + except Exception: + pass + else: + assert_is_valid_plot_return_object(ret) + + with ensure_clean() as path: + plt.savefig(path) + finally: + tm.close(fig) def curpath(): diff --git a/pandas/tests/test_rplot.py b/pandas/tests/test_rplot.py index e7faa8f25deb3..d59b182b77d4c 100644 --- a/pandas/tests/test_rplot.py +++ b/pandas/tests/test_rplot.py @@ -8,16 +8,11 @@ import nose -try: - import matplotlib.pyplot as plt -except: - raise nose.SkipTest - - def curpath(): pth, _ = os.path.split(os.path.abspath(__file__)) return pth + def between(a, b, x): """Check if x is in the somewhere between a and b. @@ -36,6 +31,8 @@ def between(a, b, x): else: return x <= a and x >= b + +@tm.mplskip class TestUtilityFunctions(unittest.TestCase): """ Tests for RPlot utility functions. @@ -74,9 +71,9 @@ def test_dictionary_union(self): self.assertTrue(2 in keys) self.assertTrue(3 in keys) self.assertTrue(4 in keys) - self.assertTrue(rplot.dictionary_union(dict1, {}) == dict1) - self.assertTrue(rplot.dictionary_union({}, dict1) == dict1) - self.assertTrue(rplot.dictionary_union({}, {}) == {}) + self.assertEqual(rplot.dictionary_union(dict1, {}), dict1) + self.assertEqual(rplot.dictionary_union({}, dict1), dict1) + self.assertEqual(rplot.dictionary_union({}, {}), {}) def test_merge_aes(self): layer1 = rplot.Layer(size=rplot.ScaleSize('test')) @@ -84,14 +81,15 @@ def test_merge_aes(self): rplot.merge_aes(layer1, layer2) self.assertTrue(isinstance(layer2.aes['size'], rplot.ScaleSize)) self.assertTrue(isinstance(layer2.aes['shape'], rplot.ScaleShape)) - self.assertTrue(layer2.aes['size'] == layer1.aes['size']) + self.assertEqual(layer2.aes['size'], layer1.aes['size']) for key in layer2.aes.keys(): if key != 'size' and key != 'shape': self.assertTrue(layer2.aes[key] is None) def test_sequence_layers(self): layer1 = rplot.Layer(self.data) - layer2 = rplot.GeomPoint(x='SepalLength', y='SepalWidth', size=rplot.ScaleSize('PetalLength')) + layer2 = rplot.GeomPoint(x='SepalLength', y='SepalWidth', + size=rplot.ScaleSize('PetalLength')) layer3 = rplot.GeomPolyFit(2) result = rplot.sequence_layers([layer1, layer2, layer3]) self.assertEqual(len(result), 3) @@ -102,6 +100,8 @@ def test_sequence_layers(self): self.assertTrue(self.data is last.data) self.assertTrue(rplot.sequence_layers([layer1])[0] is layer1) + +@tm.mplskip class TestTrellis(unittest.TestCase): def setUp(self): path = os.path.join(curpath(), 'data/tips.csv') @@ -148,11 +148,15 @@ def test_trellis_cols_rows(self): self.assertEqual(self.trellis3.cols, 2) self.assertEqual(self.trellis3.rows, 1) + +@tm.mplskip class TestScaleGradient(unittest.TestCase): def setUp(self): path = os.path.join(curpath(), 'data/iris.csv') self.data = read_csv(path, sep=',') - self.gradient = rplot.ScaleGradient("SepalLength", colour1=(0.2, 0.3, 0.4), colour2=(0.8, 0.7, 0.6)) + self.gradient = rplot.ScaleGradient("SepalLength", colour1=(0.2, 0.3, + 0.4), + colour2=(0.8, 0.7, 0.6)) def test_gradient(self): for index in range(len(self.data)): @@ -164,6 +168,8 @@ def test_gradient(self): self.assertTrue(between(g1, g2, g)) self.assertTrue(between(b1, b2, b)) + +@tm.mplskip class TestScaleGradient2(unittest.TestCase): def setUp(self): path = os.path.join(curpath(), 'data/iris.csv') @@ -190,6 +196,8 @@ def test_gradient2(self): self.assertTrue(between(g2, g3, g)) self.assertTrue(between(b2, b3, b)) + +@tm.mplskip class TestScaleRandomColour(unittest.TestCase): def setUp(self): path = os.path.join(curpath(), 'data/iris.csv') @@ -208,6 +216,8 @@ def test_random_colour(self): self.assertTrue(g <= 1.0) self.assertTrue(b <= 1.0) + +@tm.mplskip class TestScaleConstant(unittest.TestCase): def test_scale_constant(self): scale = rplot.ScaleConstant(1.0) @@ -215,6 +225,7 @@ def test_scale_constant(self): scale = rplot.ScaleConstant("test") self.assertEqual(scale(None, None), "test") + class TestScaleSize(unittest.TestCase): def setUp(self): path = os.path.join(curpath(), 'data/iris.csv') @@ -235,8 +246,10 @@ def f(): self.assertRaises(ValueError, f) +@tm.mplskip class TestRPlot(unittest.TestCase): def test_rplot1(self): + import matplotlib.pyplot as plt path = os.path.join(curpath(), 'data/tips.csv') plt.figure() self.data = read_csv(path, sep=',') @@ -247,6 +260,7 @@ def test_rplot1(self): self.plot.render(self.fig) def test_rplot2(self): + import matplotlib.pyplot as plt path = os.path.join(curpath(), 'data/tips.csv') plt.figure() self.data = read_csv(path, sep=',') @@ -257,6 +271,7 @@ def test_rplot2(self): self.plot.render(self.fig) def test_rplot3(self): + import matplotlib.pyplot as plt path = os.path.join(curpath(), 'data/tips.csv') plt.figure() self.data = read_csv(path, sep=',') @@ -267,6 +282,7 @@ def test_rplot3(self): self.plot.render(self.fig) def test_rplot_iris(self): + import matplotlib.pyplot as plt path = os.path.join(curpath(), 'data/iris.csv') plt.figure() self.data = read_csv(path, sep=',') @@ -277,5 +293,6 @@ def test_rplot_iris(self): self.fig = plt.gcf() plot.render(self.fig) + if __name__ == '__main__': unittest.main() diff --git a/pandas/tseries/tests/test_plotting.py b/pandas/tseries/tests/test_plotting.py index a22d2a65248a9..96888df114950 100644 --- a/pandas/tseries/tests/test_plotting.py +++ b/pandas/tseries/tests/test_plotting.py @@ -26,16 +26,8 @@ def _skip_if_no_scipy(): raise nose.SkipTest +@tm.mplskip class TestTSPlot(unittest.TestCase): - - @classmethod - def setUpClass(cls): - try: - import matplotlib as mpl - mpl.use('Agg', warn=False) - except ImportError: - raise nose.SkipTest - def setUp(self): freq = ['S', 'T', 'H', 'D', 'W', 'M', 'Q', 'Y'] idx = [period_range('12/31/1999', freq=x, periods=100) for x in freq] @@ -52,9 +44,7 @@ def setUp(self): for x in idx] def tearDown(self): - import matplotlib.pyplot as plt - for fignum in plt.get_fignums(): - plt.close(fignum) + tm.close() @slow def test_ts_plot_with_tz(self): diff --git a/pandas/util/testing.py b/pandas/util/testing.py index e7e930320116b..a5a96d3e03cac 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -18,6 +18,8 @@ from numpy.random import randn, rand import numpy as np +import nose + from pandas.core.common import isnull, _is_sequence import pandas.core.index as index import pandas.core.series as series @@ -70,6 +72,31 @@ def choice(x, size=10): except AttributeError: return np.random.randint(len(x), size=size).choose(x) + +def close(fignum=None): + from matplotlib.pyplot import get_fignums, close as _close + + if fignum is None: + for fignum in get_fignums(): + _close(fignum) + else: + _close(fignum) + + +def mplskip(cls): + """Skip a TestCase instance if matplotlib isn't installed""" + @classmethod + def setUpClass(cls): + try: + import matplotlib as mpl + mpl.use("Agg", warn=False) + except ImportError: + raise nose.SkipTest("matplotlib not installed") + + cls.setUpClass = setUpClass + return cls + + #------------------------------------------------------------------------------ # Console debugging tools