Skip to content

BUG: allow numpy.array as c values to scatterplot #8929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 3, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/source/whatsnew/v0.15.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ Bug Fixes
and the last offset is not calculated from the start of the range (:issue:`8683`)



- Bug where ``DataFrame.plot(kind='scatter')`` fails when ``c`` is passed as an ``np.array``, because the unhashable array was tested for membership in the DataFrame's columns (:issue:`8852`)



- Bug in `pd.infer_freq`/`DataFrame.inferred_freq` that prevented proper sub-daily frequency inference
when the index contained DST days (:issue:`8772`).
- Bug where index name was still used when plotting a series with ``use_index=False`` (:issue:`8558`).
Expand Down
32 changes: 32 additions & 0 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2504,6 +2504,38 @@ def is_list_like(arg):
not isinstance(arg, compat.string_and_binary_types))


def is_hashable(arg):
    """Return True if hash(arg) will succeed, False otherwise.

    Some types will pass an isinstance test against the ``Hashable`` ABC but
    fail when they are actually hashed with hash().

    Distinguish between these and other types by trying the call to hash()
    and seeing if it raises TypeError.

    Parameters
    ----------
    arg : object
        Any object to test.

    Returns
    -------
    bool
        True only if `arg` is an instance of the ``Hashable`` ABC *and*
        ``hash(arg)`` does not raise TypeError.

    Examples
    --------
    >>> a = ([],)
    >>> isinstance(a, collections.abc.Hashable)
    True
    >>> is_hashable(a)
    False
    """
    # The ABC lives in collections.abc on Python 3; the old
    # ``collections.Hashable`` alias was removed in Python 3.10. Fall back to
    # the legacy location only where collections.abc does not exist (Python 2).
    try:
        from collections.abc import Hashable
    except ImportError:
        from collections import Hashable

    # don't consider anything that is not an instance of the Hashable ABC, so
    # as not to broaden the definition of hashable beyond that. For example,
    # old-style classes are not Hashable instances but they won't fail hash().
    if not isinstance(arg, Hashable):
        return False

    # narrow the definition of hashable if hash(arg) fails in practice
    try:
        hash(arg)
    except TypeError:
        return False
    else:
        return True


def is_sequence(x):
try:
iter(x)
Expand Down
51 changes: 51 additions & 0 deletions pandas/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
from datetime import datetime
import re
import sys

import nose
from nose.tools import assert_equal
Expand Down Expand Up @@ -398,6 +400,55 @@ def test_is_list_like():
assert not com.is_list_like(f)


def test_is_hashable():
    """Check com.is_hashable() against the Hashable ABC and real hash() calls.

    Three groups of objects are exercised: those that are hashable, those the
    ABC already reports as unhashable, and those the ABC reports as hashable
    but whose hash() call raises TypeError.
    """
    # The ABC lives in collections.abc on Python 3; the old
    # ``collections.Hashable`` alias was removed in Python 3.10. Fall back to
    # the legacy location only where collections.abc does not exist (Python 2).
    try:
        from collections.abc import Hashable
    except ImportError:
        from collections import Hashable

    # all new-style classes are hashable by default
    class HashableClass(object):
        pass

    class UnhashableClass1(object):
        __hash__ = None

    class UnhashableClass2(object):
        def __hash__(self):
            raise TypeError("Not hashable")

    hashable = (
        1, 'a', tuple(), (1,), HashableClass(),
    )
    not_hashable = (
        [], UnhashableClass1(),
    )
    abc_hashable_not_really_hashable = (
        ([],), UnhashableClass2(),
    )

    for i in hashable:
        assert isinstance(i, Hashable)
        assert com.is_hashable(i)
    for i in not_hashable:
        assert not isinstance(i, Hashable)
        assert not com.is_hashable(i)
    for i in abc_hashable_not_really_hashable:
        assert isinstance(i, Hashable)
        assert not com.is_hashable(i)

    # numpy.array is no longer an instance of the Hashable ABC as of
    # https://github.com/numpy/numpy/pull/5326, just test
    # pandas.common.is_hashable()
    assert not com.is_hashable(np.array([]))

    # old-style classes in Python 2 don't appear hashable to the Hashable ABC
    # but also seem to support hash() by default
    if sys.version_info[0] == 2:
        class OldStyleClass():
            pass
        c = OldStyleClass()
        assert not isinstance(c, Hashable)
        assert not com.is_hashable(c)
        hash(c)  # this will not raise


def test_ensure_int32():
values = np.arange(10, dtype=np.int32)
result = com._ensure_int32(values)
Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,31 @@ def test_plot_scatter_with_c(self):
self.assertIs(ax.collections[0].colorbar, None)
self._check_colors(ax.collections, facecolors=['r'])

# Ensure that we can pass an np.array straight through to matplotlib,
# this functionality was accidentally removed previously.
# See https://github.com/pydata/pandas/issues/8852 for bug report
#
# Exercise colormap path and non-colormap path as they are independent
#
df = DataFrame({'A': [1, 2], 'B': [3, 4]})
red_rgba = [1.0, 0.0, 0.0, 1.0]
green_rgba = [0.0, 1.0, 0.0, 1.0]
rgba_array = np.array([red_rgba, green_rgba])
ax = df.plot(kind='scatter', x='A', y='B', c=rgba_array)
# expect the face colors of the points in the non-colormap path to be
# identical to the values we supplied, normally we'd be on shaky ground
# comparing floats for equality but here we expect them to be
# identical.
self.assertTrue(
np.array_equal(
ax.collections[0].get_facecolor(),
rgba_array))
# we don't test the colors of the faces in this next plot because they
# are dependent on the spring colormap, which may change its colors
# later.
float_array = np.array([0.0, 1.0])
df.plot(kind='scatter', x='A', y='B', c=float_array, cmap='spring')

@slow
def test_plot_bar(self):
df = DataFrame(randn(6, 4),
Expand Down
8 changes: 5 additions & 3 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,16 +1403,18 @@ def _make_plot(self):
x, y, c, data = self.x, self.y, self.c, self.data
ax = self.axes[0]

c_is_column = com.is_hashable(c) and c in self.data.columns

# plot a colorbar only if a colormap is provided or necessary
cb = self.kwds.pop('colorbar', self.colormap or c in self.data.columns)
cb = self.kwds.pop('colorbar', self.colormap or c_is_column)

# pandas uses colormap, matplotlib uses cmap.
cmap = self.colormap or 'Greys'
cmap = plt.cm.get_cmap(cmap)

if c is None:
c_values = self.plt.rcParams['patch.facecolor']
elif c in self.data.columns:
elif c_is_column:
c_values = self.data[c].values
else:
c_values = c
Expand All @@ -1427,7 +1429,7 @@ def _make_plot(self):
img = ax.collections[0]
kws = dict(ax=ax)
if mpl_ge_1_3_1:
kws['label'] = c if c in self.data.columns else ''
kws['label'] = c if c_is_column else ''
self.fig.colorbar(img, **kws)

self._add_legend_handle(scatter, label)
Expand Down