Skip to content

Commit cb8c130

Browse files
author
Tom Augspurger
committed
Merge pull request #9441 from sinhrks/sm_axis
BUG: scatter_matrix draws incorrect axis
2 parents 587e410 + 0bfbe62 commit cb8c130

File tree

3 files changed

+42
-33
lines changed

3 files changed

+42
-33
lines changed

doc/source/whatsnew/v0.16.1.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ Performance Improvements
5858
Bug Fixes
5959
~~~~~~~~~
6060

61-
- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated.
61+
- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated.
6262

6363

64+
- Bug in ``scatter_matrix`` draws unexpected axis ticklabels (:issue:`5662`)
65+
6466

6567

6668

pandas/tests/test_graphics.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -2353,10 +2353,9 @@ def test_scatter(self):
23532353
tm._skip_if_no_scipy()
23542354

23552355
df = DataFrame(randn(100, 2))
2356-
import pandas.tools.plotting as plt
23572356

23582357
def scat(**kwds):
2359-
return plt.scatter_matrix(df, **kwds)
2358+
return plotting.scatter_matrix(df, **kwds)
23602359

23612360
_check_plot_works(scat)
23622361
_check_plot_works(scat, marker='+')
@@ -2369,12 +2368,33 @@ def scat(**kwds):
23692368
_check_plot_works(scat, range_padding=.1)
23702369

23712370
def scat2(x, y, by=None, ax=None, figsize=None):
2372-
return plt.scatter_plot(df, x, y, by, ax, figsize=None)
2371+
return plotting.scatter_plot(df, x, y, by, ax, figsize=None)
23732372

23742373
_check_plot_works(scat2, 0, 1)
23752374
grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index)
23762375
_check_plot_works(scat2, 0, 1, by=grouper)
23772376

2377+
def test_scatter_matrix_axis(self):
2378+
tm._skip_if_no_scipy()
2379+
scatter_matrix = plotting.scatter_matrix
2380+
2381+
with tm.RNGContext(42):
2382+
df = DataFrame(randn(100, 3))
2383+
2384+
axes = _check_plot_works(scatter_matrix, df, range_padding=.1)
2385+
axes0_labels = axes[0][0].yaxis.get_majorticklabels()
2386+
# GH 5662
2387+
expected = ['-2', '-1', '0', '1', '2']
2388+
self._check_text_labels(axes0_labels, expected)
2389+
self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
2390+
2391+
df[0] = ((df[0] - 2) / 3)
2392+
axes = _check_plot_works(scatter_matrix, df, range_padding=.1)
2393+
axes0_labels = axes[0][0].yaxis.get_majorticklabels()
2394+
expected = ['-1.2', '-1.0', '-0.8', '-0.6', '-0.4', '-0.2', '0.0']
2395+
self._check_text_labels(axes0_labels, expected)
2396+
self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
2397+
23782398
@slow
23792399
def test_andrews_curves(self):
23802400
from pandas.tools.plotting import andrews_curves

pandas/tools/plotting.py

+16-29
Original file line numberDiff line numberDiff line change
@@ -303,45 +303,32 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
303303
ax.set_xlim(boundaries_list[j])
304304
ax.set_ylim(boundaries_list[i])
305305

306-
ax.set_xlabel('')
307-
ax.set_ylabel('')
308-
309-
_label_axis(ax, kind='x', label=b, position='bottom', rotate=True)
310-
311-
_label_axis(ax, kind='y', label=a, position='left')
306+
ax.set_xlabel(b)
307+
ax.set_ylabel(a)
312308

313309
if j!= 0:
314310
ax.yaxis.set_visible(False)
315311
if i != n-1:
316312
ax.xaxis.set_visible(False)
317313

318-
for ax in axes.flat:
319-
setp(ax.get_xticklabels(), fontsize=8)
320-
setp(ax.get_yticklabels(), fontsize=8)
314+
if len(df.columns) > 1:
315+
lim1 = boundaries_list[0]
316+
locs = axes[0][1].yaxis.get_majorticklocs()
317+
locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
318+
adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
321319

322-
return axes
323-
324-
def _label_axis(ax, kind='x', label='', position='top',
325-
ticks=True, rotate=False):
326-
327-
from matplotlib.artist import setp
328-
if kind == 'x':
329-
ax.set_xlabel(label, visible=True)
330-
ax.xaxis.set_visible(True)
331-
ax.xaxis.set_ticks_position(position)
332-
ax.xaxis.set_label_position(position)
333-
if rotate:
334-
setp(ax.get_xticklabels(), rotation=90)
335-
elif kind == 'y':
336-
ax.yaxis.set_visible(True)
337-
ax.set_ylabel(label, visible=True)
338-
# ax.set_ylabel(a)
339-
ax.yaxis.set_ticks_position(position)
340-
ax.yaxis.set_label_position(position)
341-
return
320+
lim0 = axes[0][0].get_ylim()
321+
adj = adj * (lim0[1] - lim0[0]) + lim0[0]
322+
axes[0][0].yaxis.set_ticks(adj)
342323

324+
if np.all(locs == locs.astype(int)):
325+
# if all ticks are int
326+
locs = locs.astype(int)
327+
axes[0][0].yaxis.set_ticklabels(locs)
343328

329+
_set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
344330

331+
return axes
345332

346333

347334
def _gca():

0 commit comments

Comments
 (0)