Skip to content

Commit 11d274f

Browse files
stangirala authored and TomAugspurger committed
BUG: Categorical scatter plot has KeyError pandas-dev#16199 (pandas-dev#16208)
* BUG: Categorical scatter plot has KeyError pandas-dev#16199. Appropriately handles categorical data for dataframe scatter plots, which currently raise a KeyError for categorical data.
* Add to whatsnew
1 parent b72519e commit 11d274f

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

doc/source/whatsnew/v0.20.3.txt

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Performance Improvements
3636

3737
Bug Fixes
3838
~~~~~~~~~
39+
- Fixed an issue where a dataframe scatter plot incorrectly reported that a column key was not found when categorical data was used for plotting (:issue:`16199`)
3940

4041

4142

pandas/plotting/_core.py

+5
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,11 @@ def __init__(self, data, x, y, **kwargs):
778778
x = self.data.columns[x]
779779
if is_integer(y) and not self.data.columns.holds_integer():
780780
y = self.data.columns[y]
781+
if len(self.data[x]._get_numeric_data()) == 0:
782+
raise ValueError(self._kind + ' requires x column to be numeric')
783+
if len(self.data[y]._get_numeric_data()) == 0:
784+
raise ValueError(self._kind + ' requires y column to be numeric')
785+
781786
self.x = x
782787
self.y = y
783788

pandas/tests/plotting/test_frame.py

+18
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,24 @@ def test_plot_scatter(self):
915915
axes = df.plot(x='x', y='y', kind='scatter', subplots=True)
916916
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
917917

918+
@slow
919+
def test_plot_scatter_with_categorical_data(self):
920+
# GH 16199
921+
df = pd.DataFrame({'x': [1, 2, 3, 4],
922+
'y': pd.Categorical(['a', 'b', 'a', 'c'])})
923+
924+
with pytest.raises(ValueError) as ve:
925+
df.plot(x='x', y='y', kind='scatter')
926+
ve.match('requires y column to be numeric')
927+
928+
with pytest.raises(ValueError) as ve:
929+
df.plot(x='y', y='x', kind='scatter')
930+
ve.match('requires x column to be numeric')
931+
932+
with pytest.raises(ValueError) as ve:
933+
df.plot(x='y', y='y', kind='scatter')
934+
ve.match('requires x column to be numeric')
935+
918936
@slow
919937
def test_plot_scatter_with_c(self):
920938
df = DataFrame(randn(6, 4),

0 commit comments

Comments
 (0)