Skip to content

CLN: more plotting test cleanups #4912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 21, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 37 additions & 57 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,11 @@ def _skip_if_no_scipy():
raise nose.SkipTest


@tm.mplskip
class TestSeriesPlots(unittest.TestCase):
@classmethod
def setUpClass(cls):
try:
import matplotlib as mpl
mpl.use('Agg', warn=False)
cls.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
except ImportError:
raise nose.SkipTest("matplotlib not installed")

def setUp(self):
import matplotlib as mpl
self.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
self.ts = tm.makeTimeSeries()
self.ts.name = 'ts'

Expand All @@ -50,8 +44,7 @@ def setUp(self):
self.iseries.name = 'iseries'

def tearDown(self):
import matplotlib.pyplot as plt
plt.close('all')
tm.close()

@slow
def test_plot(self):
Expand Down Expand Up @@ -352,24 +345,14 @@ def test_dup_datetime_index_plot(self):
_check_plot_works(s.plot)


@tm.mplskip
class TestDataFramePlots(unittest.TestCase):

@classmethod
def setUpClass(cls):
# import sys
# if 'IPython' in sys.modules:
# raise nose.SkipTest

try:
import matplotlib as mpl
mpl.use('Agg', warn=False)
cls.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
except ImportError:
raise nose.SkipTest("matplotlib not installed")
def setUp(self):
import matplotlib as mpl
self.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')

def tearDown(self):
import matplotlib.pyplot as plt
plt.close('all')
tm.close()

@slow
def test_plot(self):
Expand Down Expand Up @@ -949,19 +932,10 @@ def test_invalid_kind(self):
df.plot(kind='aasdf')


@tm.mplskip
class TestDataFrameGroupByPlots(unittest.TestCase):
@classmethod
def setUpClass(cls):
try:
import matplotlib as mpl
mpl.use('Agg', warn=False)
except ImportError:
raise nose.SkipTest

def tearDown(self):
import matplotlib.pyplot as plt
for fignum in plt.get_fignums():
plt.close(fignum)
tm.close()

@slow
def test_boxplot(self):
Expand Down Expand Up @@ -999,13 +973,16 @@ def test_time_series_plot_color_with_empty_kwargs(self):
import matplotlib as mpl

def_colors = mpl.rcParams['axes.color_cycle']
index = date_range('1/1/2000', periods=12)
s = Series(np.arange(1, 13), index=index)

ncolors = 3

for i in range(3):
ax = Series(np.arange(12) + 1, index=date_range('1/1/2000',
periods=12)).plot()
for i in range(ncolors):
ax = s.plot()

line_colors = [l.get_color() for l in ax.get_lines()]
self.assertEqual(line_colors, def_colors[:3])
self.assertEqual(line_colors, def_colors[:ncolors])

@slow
def test_grouped_hist(self):
Expand Down Expand Up @@ -1155,27 +1132,30 @@ def _check_plot_works(f, *args, **kwargs):
import matplotlib.pyplot as plt

try:
fig = kwargs['figure']
except KeyError:
fig = plt.gcf()
plt.clf()
ax = kwargs.get('ax', fig.add_subplot(211))
ret = f(*args, **kwargs)
try:
fig = kwargs['figure']
except KeyError:
fig = plt.gcf()

assert ret is not None
assert_is_valid_plot_return_object(ret)
plt.clf()

try:
kwargs['ax'] = fig.add_subplot(212)
ax = kwargs.get('ax', fig.add_subplot(211))
ret = f(*args, **kwargs)
except Exception:
pass
else:

assert_is_valid_plot_return_object(ret)

with ensure_clean() as path:
plt.savefig(path)
plt.close(fig)
try:
kwargs['ax'] = fig.add_subplot(212)
ret = f(*args, **kwargs)
except Exception:
pass
else:
assert_is_valid_plot_return_object(ret)

with ensure_clean() as path:
plt.savefig(path)
finally:
tm.close(fig)


def curpath():
Expand Down
41 changes: 29 additions & 12 deletions pandas/tests/test_rplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,11 @@
import nose


try:
import matplotlib.pyplot as plt
except:
raise nose.SkipTest


def curpath():
pth, _ = os.path.split(os.path.abspath(__file__))
return pth


def between(a, b, x):
"""Check if x is in the somewhere between a and b.
Expand All @@ -36,6 +31,8 @@ def between(a, b, x):
else:
return x <= a and x >= b


@tm.mplskip
class TestUtilityFunctions(unittest.TestCase):
"""
Tests for RPlot utility functions.
Expand Down Expand Up @@ -74,24 +71,25 @@ def test_dictionary_union(self):
self.assertTrue(2 in keys)
self.assertTrue(3 in keys)
self.assertTrue(4 in keys)
self.assertTrue(rplot.dictionary_union(dict1, {}) == dict1)
self.assertTrue(rplot.dictionary_union({}, dict1) == dict1)
self.assertTrue(rplot.dictionary_union({}, {}) == {})
self.assertEqual(rplot.dictionary_union(dict1, {}), dict1)
self.assertEqual(rplot.dictionary_union({}, dict1), dict1)
self.assertEqual(rplot.dictionary_union({}, {}), {})

def test_merge_aes(self):
layer1 = rplot.Layer(size=rplot.ScaleSize('test'))
layer2 = rplot.Layer(shape=rplot.ScaleShape('test'))
rplot.merge_aes(layer1, layer2)
self.assertTrue(isinstance(layer2.aes['size'], rplot.ScaleSize))
self.assertTrue(isinstance(layer2.aes['shape'], rplot.ScaleShape))
self.assertTrue(layer2.aes['size'] == layer1.aes['size'])
self.assertEqual(layer2.aes['size'], layer1.aes['size'])
for key in layer2.aes.keys():
if key != 'size' and key != 'shape':
self.assertTrue(layer2.aes[key] is None)

def test_sequence_layers(self):
layer1 = rplot.Layer(self.data)
layer2 = rplot.GeomPoint(x='SepalLength', y='SepalWidth', size=rplot.ScaleSize('PetalLength'))
layer2 = rplot.GeomPoint(x='SepalLength', y='SepalWidth',
size=rplot.ScaleSize('PetalLength'))
layer3 = rplot.GeomPolyFit(2)
result = rplot.sequence_layers([layer1, layer2, layer3])
self.assertEqual(len(result), 3)
Expand All @@ -102,6 +100,8 @@ def test_sequence_layers(self):
self.assertTrue(self.data is last.data)
self.assertTrue(rplot.sequence_layers([layer1])[0] is layer1)


@tm.mplskip
class TestTrellis(unittest.TestCase):
def setUp(self):
path = os.path.join(curpath(), 'data/tips.csv')
Expand Down Expand Up @@ -148,11 +148,15 @@ def test_trellis_cols_rows(self):
self.assertEqual(self.trellis3.cols, 2)
self.assertEqual(self.trellis3.rows, 1)


@tm.mplskip
class TestScaleGradient(unittest.TestCase):
def setUp(self):
path = os.path.join(curpath(), 'data/iris.csv')
self.data = read_csv(path, sep=',')
self.gradient = rplot.ScaleGradient("SepalLength", colour1=(0.2, 0.3, 0.4), colour2=(0.8, 0.7, 0.6))
self.gradient = rplot.ScaleGradient("SepalLength", colour1=(0.2, 0.3,
0.4),
colour2=(0.8, 0.7, 0.6))

def test_gradient(self):
for index in range(len(self.data)):
Expand All @@ -164,6 +168,8 @@ def test_gradient(self):
self.assertTrue(between(g1, g2, g))
self.assertTrue(between(b1, b2, b))


@tm.mplskip
class TestScaleGradient2(unittest.TestCase):
def setUp(self):
path = os.path.join(curpath(), 'data/iris.csv')
Expand All @@ -190,6 +196,8 @@ def test_gradient2(self):
self.assertTrue(between(g2, g3, g))
self.assertTrue(between(b2, b3, b))


@tm.mplskip
class TestScaleRandomColour(unittest.TestCase):
def setUp(self):
path = os.path.join(curpath(), 'data/iris.csv')
Expand All @@ -208,13 +216,16 @@ def test_random_colour(self):
self.assertTrue(g <= 1.0)
self.assertTrue(b <= 1.0)


@tm.mplskip
class TestScaleConstant(unittest.TestCase):
def test_scale_constant(self):
scale = rplot.ScaleConstant(1.0)
self.assertEqual(scale(None, None), 1.0)
scale = rplot.ScaleConstant("test")
self.assertEqual(scale(None, None), "test")


class TestScaleSize(unittest.TestCase):
def setUp(self):
path = os.path.join(curpath(), 'data/iris.csv')
Expand All @@ -235,8 +246,10 @@ def f():
self.assertRaises(ValueError, f)


@tm.mplskip
class TestRPlot(unittest.TestCase):
def test_rplot1(self):
import matplotlib.pyplot as plt
path = os.path.join(curpath(), 'data/tips.csv')
plt.figure()
self.data = read_csv(path, sep=',')
Expand All @@ -247,6 +260,7 @@ def test_rplot1(self):
self.plot.render(self.fig)

def test_rplot2(self):
import matplotlib.pyplot as plt
path = os.path.join(curpath(), 'data/tips.csv')
plt.figure()
self.data = read_csv(path, sep=',')
Expand All @@ -257,6 +271,7 @@ def test_rplot2(self):
self.plot.render(self.fig)

def test_rplot3(self):
import matplotlib.pyplot as plt
path = os.path.join(curpath(), 'data/tips.csv')
plt.figure()
self.data = read_csv(path, sep=',')
Expand All @@ -267,6 +282,7 @@ def test_rplot3(self):
self.plot.render(self.fig)

def test_rplot_iris(self):
import matplotlib.pyplot as plt
path = os.path.join(curpath(), 'data/iris.csv')
plt.figure()
self.data = read_csv(path, sep=',')
Expand All @@ -277,5 +293,6 @@ def test_rplot_iris(self):
self.fig = plt.gcf()
plot.render(self.fig)


if __name__ == '__main__':
unittest.main()
14 changes: 2 additions & 12 deletions pandas/tseries/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,8 @@ def _skip_if_no_scipy():
raise nose.SkipTest


@tm.mplskip
class TestTSPlot(unittest.TestCase):

@classmethod
def setUpClass(cls):
try:
import matplotlib as mpl
mpl.use('Agg', warn=False)
except ImportError:
raise nose.SkipTest

def setUp(self):
freq = ['S', 'T', 'H', 'D', 'W', 'M', 'Q', 'Y']
idx = [period_range('12/31/1999', freq=x, periods=100) for x in freq]
Expand All @@ -52,9 +44,7 @@ def setUp(self):
for x in idx]

def tearDown(self):
import matplotlib.pyplot as plt
for fignum in plt.get_fignums():
plt.close(fignum)
tm.close()

@slow
def test_ts_plot_with_tz(self):
Expand Down
27 changes: 27 additions & 0 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from numpy.random import randn, rand
import numpy as np

import nose

from pandas.core.common import isnull, _is_sequence
import pandas.core.index as index
import pandas.core.series as series
Expand Down Expand Up @@ -70,6 +72,31 @@ def choice(x, size=10):
except AttributeError:
return np.random.randint(len(x), size=size).choose(x)


def close(fignum=None):
from matplotlib.pyplot import get_fignums, close as _close

if fignum is None:
for fignum in get_fignums():
_close(fignum)
else:
_close(fignum)


def mplskip(cls):
"""Skip a TestCase instance if matplotlib isn't installed"""
@classmethod
def setUpClass(cls):
try:
import matplotlib as mpl
mpl.use("Agg", warn=False)
except ImportError:
raise nose.SkipTest("matplotlib not installed")

cls.setUpClass = setUpClass
return cls


#------------------------------------------------------------------------------
# Console debugging tools

Expand Down