From 0bfbe62c31aeb669a516b971b2ce957aafd008c0 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sat, 7 Feb 2015 11:10:33 +0900 Subject: [PATCH] BUG: scatter_matrix draws incorrect axis --- doc/source/whatsnew/v0.16.1.txt | 4 ++- pandas/tests/test_graphics.py | 26 ++++++++++++++++--- pandas/tools/plotting.py | 45 ++++++++++++--------------------- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/doc/source/whatsnew/v0.16.1.txt b/doc/source/whatsnew/v0.16.1.txt index 55922091556c1..ca316bbac8474 100644 --- a/doc/source/whatsnew/v0.16.1.txt +++ b/doc/source/whatsnew/v0.16.1.txt @@ -58,9 +58,11 @@ Performance Improvements Bug Fixes ~~~~~~~~~ -- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated. +- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated. +- Bug in ``scatter_matrix`` draws unexpected axis ticklabels (:issue:`5662`) + diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index a7b82559205cd..04e43fabcc1cc 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -2353,10 +2353,9 @@ def test_scatter(self): tm._skip_if_no_scipy() df = DataFrame(randn(100, 2)) - import pandas.tools.plotting as plt def scat(**kwds): - return plt.scatter_matrix(df, **kwds) + return plotting.scatter_matrix(df, **kwds) _check_plot_works(scat) _check_plot_works(scat, marker='+') @@ -2369,12 +2368,33 @@ def scat(**kwds): _check_plot_works(scat, range_padding=.1) def scat2(x, y, by=None, ax=None, figsize=None): - return plt.scatter_plot(df, x, y, by, ax, figsize=None) + return plotting.scatter_plot(df, x, y, by, ax, figsize=None) _check_plot_works(scat2, 0, 1) grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index) _check_plot_works(scat2, 0, 1, by=grouper) + def test_scatter_matrix_axis(self): + tm._skip_if_no_scipy() + scatter_matrix = plotting.scatter_matrix + + with tm.RNGContext(42): + df = DataFrame(randn(100, 3)) + + axes = _check_plot_works(scatter_matrix, df, range_padding=.1) + axes0_labels = axes[0][0].yaxis.get_majorticklabels() + # GH 5662 + expected = ['-2', '-1', '0', '1', '2'] + self._check_text_labels(axes0_labels, expected) + self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + + df[0] = ((df[0] - 2) / 3) + axes = _check_plot_works(scatter_matrix, df, range_padding=.1) + axes0_labels = axes[0][0].yaxis.get_majorticklabels() + expected = ['-1.2', '-1.0', '-0.8', '-0.6', '-0.4', '-0.2', '0.0'] + self._check_text_labels(axes0_labels, expected) + self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + @slow def test_andrews_curves(self): from pandas.tools.plotting import andrews_curves diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index cc9959d2c0efa..c7130a144adea 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -303,45 +303,32 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, ax.set_xlim(boundaries_list[j]) ax.set_ylim(boundaries_list[i]) - ax.set_xlabel('') - ax.set_ylabel('') - - _label_axis(ax, kind='x', label=b, position='bottom', rotate=True) - - _label_axis(ax, kind='y', label=a, position='left') + ax.set_xlabel(b) + ax.set_ylabel(a) if j!= 0: ax.yaxis.set_visible(False) if i != n-1: ax.xaxis.set_visible(False) - for ax in axes.flat: - setp(ax.get_xticklabels(), fontsize=8) - setp(ax.get_yticklabels(), fontsize=8) + if len(df.columns) > 1: + lim1 = boundaries_list[0] + locs = axes[0][1].yaxis.get_majorticklocs() + locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])] + adj = (locs - lim1[0]) / (lim1[1] - lim1[0]) - return axes - -def _label_axis(ax, kind='x', label='', position='top', - ticks=True, rotate=False): - - from matplotlib.artist import setp - if kind == 'x': - ax.set_xlabel(label, visible=True) - ax.xaxis.set_visible(True) - ax.xaxis.set_ticks_position(position) - ax.xaxis.set_label_position(position) - if rotate: - setp(ax.get_xticklabels(), rotation=90) - elif kind == 'y': - ax.yaxis.set_visible(True) - ax.set_ylabel(label, visible=True) - # ax.set_ylabel(a) - ax.yaxis.set_ticks_position(position) - ax.yaxis.set_label_position(position) - return + lim0 = axes[0][0].get_ylim() + adj = adj * (lim0[1] - lim0[0]) + lim0[0] + axes[0][0].yaxis.set_ticks(adj) + if np.all(locs == locs.astype(int)): + # if all ticks are int + locs = locs.astype(int) + axes[0][0].yaxis.set_ticklabels(locs) + _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + return axes def _gca():