Skip to content

BUG: scatter_matrix draws incorrect axis #9441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/source/whatsnew/v0.16.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ Performance Improvements
Bug Fixes
~~~~~~~~~

- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated.
- Fixed bug (:issue:`9542`) where labels did not appear properly in legend of ``DataFrame.plot()``. Passing ``label=`` args also now works, and series indices are no longer mutated.


- Bug in ``scatter_matrix`` where unexpected axis ticklabels were drawn (:issue:`5662`)




Expand Down
26 changes: 23 additions & 3 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2353,10 +2353,9 @@ def test_scatter(self):
tm._skip_if_no_scipy()

df = DataFrame(randn(100, 2))
import pandas.tools.plotting as plt

def scat(**kwds):
return plt.scatter_matrix(df, **kwds)
return plotting.scatter_matrix(df, **kwds)

_check_plot_works(scat)
_check_plot_works(scat, marker='+')
Expand All @@ -2369,12 +2368,33 @@ def scat(**kwds):
_check_plot_works(scat, range_padding=.1)

def scat2(x, y, by=None, ax=None, figsize=None):
return plt.scatter_plot(df, x, y, by, ax, figsize=None)
return plotting.scatter_plot(df, x, y, by, ax, figsize=None)

_check_plot_works(scat2, 0, 1)
grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index)
_check_plot_works(scat2, 0, 1, by=grouper)

def test_scatter_matrix_axis(self):
    # GH 5662: the y axis of the top-left subplot of a scatter matrix must
    # carry the expected tick labels, and every subplot must use the
    # standard tick styling (size 8, x labels rotated by 90 degrees).
    tm._skip_if_no_scipy()

    def plot_and_verify(frame, expected_labels):
        # Draw the matrix, then inspect the first subplot's y tick labels
        # followed by the tick properties of all returned axes.
        axes = _check_plot_works(plotting.scatter_matrix, frame,
                                 range_padding=.1)
        first_y_labels = axes[0][0].yaxis.get_majorticklabels()
        self._check_text_labels(first_y_labels, expected_labels)
        self._check_ticks_props(axes, xlabelsize=8, xrot=90,
                                ylabelsize=8, yrot=0)

    # NOTE(review): the scraped diff lost its indentation; only the frame
    # construction is assumed to sit inside the RNG context - confirm.
    with tm.RNGContext(42):
        df = DataFrame(randn(100, 3))

    plot_and_verify(df, ['-2', '-1', '0', '1', '2'])

    # Shift and rescale the first column so its ticks become non-integer.
    df[0] = ((df[0] - 2) / 3)
    plot_and_verify(
        df, ['-1.2', '-1.0', '-0.8', '-0.6', '-0.4', '-0.2', '0.0'])

@slow
def test_andrews_curves(self):
from pandas.tools.plotting import andrews_curves
Expand Down
45 changes: 16 additions & 29 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,45 +303,32 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
ax.set_xlim(boundaries_list[j])
ax.set_ylim(boundaries_list[i])

ax.set_xlabel('')
ax.set_ylabel('')

_label_axis(ax, kind='x', label=b, position='bottom', rotate=True)

_label_axis(ax, kind='y', label=a, position='left')
ax.set_xlabel(b)
ax.set_ylabel(a)

if j!= 0:
ax.yaxis.set_visible(False)
if i != n-1:
ax.xaxis.set_visible(False)

for ax in axes.flat:
setp(ax.get_xticklabels(), fontsize=8)
setp(ax.get_yticklabels(), fontsize=8)
if len(df.columns) > 1:
lim1 = boundaries_list[0]
locs = axes[0][1].yaxis.get_majorticklocs()
locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
adj = (locs - lim1[0]) / (lim1[1] - lim1[0])

return axes

def _label_axis(ax, kind='x', label='', position='top',
                ticks=True, rotate=False):
    """Make one axis of ``ax`` visible and attach ``label`` to it.

    Parameters
    ----------
    ax : object
        Must expose ``xaxis``/``yaxis``, ``set_xlabel``/``set_ylabel`` and
        ``get_xticklabels`` (presumably a matplotlib Axes - confirm
        against callers).
    kind : {'x', 'y'}
        Axis to configure. Any other value leaves ``ax`` untouched.
    label : str
        Text passed to ``set_xlabel`` / ``set_ylabel``.
    position : str
        Forwarded to both ``set_ticks_position`` and
        ``set_label_position`` of the selected axis.
    ticks : bool
        Accepted but never read by this function.
    rotate : bool
        Only honoured for ``kind='x'``: sets the x tick label rotation
        to 90.

    Returns
    -------
    None
    """
    from matplotlib.artist import setp

    if kind == 'y':
        y_axis = ax.yaxis
        y_axis.set_visible(True)
        ax.set_ylabel(label, visible=True)
        y_axis.set_ticks_position(position)
        y_axis.set_label_position(position)
    elif kind == 'x':
        ax.set_xlabel(label, visible=True)
        x_axis = ax.xaxis
        x_axis.set_visible(True)
        x_axis.set_ticks_position(position)
        x_axis.set_label_position(position)
        if rotate:
            setp(ax.get_xticklabels(), rotation=90)
lim0 = axes[0][0].get_ylim()
adj = adj * (lim0[1] - lim0[0]) + lim0[0]
axes[0][0].yaxis.set_ticks(adj)

if np.all(locs == locs.astype(int)):
# if all ticks are int
locs = locs.astype(int)
axes[0][0].yaxis.set_ticklabels(locs)

_set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)

return axes


def _gca():
Expand Down