Skip to content

Commit faee760

Browse files
committed
Merge pull request #6876 from sinhrks/cln_plot
CLN: simplify series plotting
2 parents b9efa31 + c76d464 commit faee760

File tree

2 files changed

+57
-82
lines changed

2 files changed

+57
-82
lines changed

pandas/tests/test_graphics.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ def test_plot(self):
5555
_check_plot_works(self.ts.plot, style='.', loglog=True)
5656
_check_plot_works(self.ts[:10].plot, kind='bar')
5757
_check_plot_works(self.iseries.plot)
58-
_check_plot_works(self.series[:5].plot, kind='bar')
59-
_check_plot_works(self.series[:5].plot, kind='line')
60-
_check_plot_works(self.series[:5].plot, kind='barh')
58+
59+
for kind in plotting._common_kinds:
60+
_check_plot_works(self.series[:5].plot, kind=kind)
61+
6162
_check_plot_works(self.series[:10].plot, kind='barh')
6263
_check_plot_works(Series(randn(10)).plot, kind='bar', color='black')
6364

@@ -250,25 +251,19 @@ def test_bootstrap_plot(self):
250251

251252
def test_invalid_plot_data(self):
252253
s = Series(list('abcd'))
253-
kinds = 'line', 'bar', 'barh', 'kde', 'density'
254-
255-
for kind in kinds:
254+
for kind in plotting._common_kinds:
256255
with tm.assertRaises(TypeError):
257256
s.plot(kind=kind)
258257

259258
@slow
260259
def test_valid_object_plot(self):
261260
s = Series(lrange(10), dtype=object)
262-
kinds = 'line', 'bar', 'barh', 'kde', 'density'
263-
264-
for kind in kinds:
261+
for kind in plotting._common_kinds:
265262
_check_plot_works(s.plot, kind=kind)
266263

267264
def test_partially_invalid_plot_data(self):
268265
s = Series(['a', 'b', 1.0, 2])
269-
kinds = 'line', 'bar', 'barh', 'kde', 'density'
270-
271-
for kind in kinds:
266+
for kind in plotting._common_kinds:
272267
with tm.assertRaises(TypeError):
273268
s.plot(kind=kind)
274269

@@ -1247,19 +1242,17 @@ def test_unordered_ts(self):
12471242
assert_array_equal(ydata, np.array([1.0, 2.0, 3.0]))
12481243

12491244
def test_all_invalid_plot_data(self):
1250-
kinds = 'line', 'bar', 'barh', 'kde', 'density'
12511245
df = DataFrame(list('abcd'))
1252-
for kind in kinds:
1246+
for kind in plotting._common_kinds:
12531247
with tm.assertRaises(TypeError):
12541248
df.plot(kind=kind)
12551249

12561250
@slow
12571251
def test_partially_invalid_plot_data(self):
12581252
with tm.RNGContext(42):
1259-
kinds = 'line', 'bar', 'barh', 'kde', 'density'
12601253
df = DataFrame(randn(10, 2), dtype=object)
12611254
df[np.random.rand(df.shape[0]) > 0.5] = 'a'
1262-
for kind in kinds:
1255+
for kind in plotting._common_kinds:
12631256
with tm.assertRaises(TypeError):
12641257
df.plot(kind=kind)
12651258

pandas/tools/plotting.py

+48-66
Original file line numberDiff line numberDiff line change
@@ -887,25 +887,31 @@ def _validate_color_args(self):
887887
" use one or the other or pass 'style' "
888888
"without a color symbol")
889889

890-
def _iter_data(self):
891-
from pandas.core.frame import DataFrame
892-
if isinstance(self.data, (Series, np.ndarray)):
893-
yield self.label, np.asarray(self.data)
894-
elif isinstance(self.data, DataFrame):
895-
df = self.data
890+
def _iter_data(self, data=None, keep_index=False):
891+
if data is None:
892+
data = self.data
896893

894+
from pandas.core.frame import DataFrame
895+
if isinstance(data, (Series, np.ndarray)):
896+
if keep_index is True:
897+
yield self.label, data
898+
else:
899+
yield self.label, np.asarray(data)
900+
elif isinstance(data, DataFrame):
897901
if self.sort_columns:
898-
columns = com._try_sort(df.columns)
902+
columns = com._try_sort(data.columns)
899903
else:
900-
columns = df.columns
904+
columns = data.columns
901905

902906
for col in columns:
903907
# # is this right?
904908
# empty = df[col].count() == 0
905909
# values = df[col].values if not empty else np.zeros(len(df))
906910

907-
values = df[col].values
908-
yield col, values
911+
if keep_index is True:
912+
yield col, data[col]
913+
else:
914+
yield col, data[col].values
909915

910916
@property
911917
def nseries(self):
@@ -1593,38 +1599,26 @@ def _plot(data, col_num, ax, label, style, **kwds):
15931599

15941600
self._add_legend_handle(newlines[0], label, index=col_num)
15951601

1596-
if isinstance(data, Series):
1597-
ax = self._get_ax(0) # self.axes[0]
1598-
style = self.style or ''
1599-
label = com.pprint_thing(self.label)
1602+
it = self._iter_data(data=data, keep_index=True)
1603+
for i, (label, y) in enumerate(it):
1604+
ax = self._get_ax(i)
1605+
style = self._get_style(i, label)
16001606
kwds = self.kwds.copy()
1601-
self._maybe_add_color(colors, kwds, style, 0)
1602-
1603-
if 'yerr' in kwds:
1604-
kwds['yerr'] = kwds['yerr'][0]
16051607

1606-
_plot(data, 0, ax, label, self.style, **kwds)
1607-
1608-
else:
1609-
for i, col in enumerate(data.columns):
1610-
label = com.pprint_thing(col)
1611-
ax = self._get_ax(i)
1612-
style = self._get_style(i, col)
1613-
kwds = self.kwds.copy()
1614-
1615-
self._maybe_add_color(colors, kwds, style, i)
1608+
self._maybe_add_color(colors, kwds, style, i)
16161609

1617-
# key-matched DataFrame of errors
1618-
if 'yerr' in kwds:
1619-
yerr = kwds['yerr']
1620-
if isinstance(yerr, (DataFrame, dict)):
1621-
if col in yerr.keys():
1622-
kwds['yerr'] = yerr[col]
1623-
else: del kwds['yerr']
1624-
else:
1625-
kwds['yerr'] = yerr[i]
1610+
# key-matched DataFrame of errors
1611+
if 'yerr' in kwds:
1612+
yerr = kwds['yerr']
1613+
if isinstance(yerr, (DataFrame, dict)):
1614+
if label in yerr.keys():
1615+
kwds['yerr'] = yerr[label]
1616+
else: del kwds['yerr']
1617+
else:
1618+
kwds['yerr'] = yerr[i]
16261619

1627-
_plot(data[col], i, ax, label, style, **kwds)
1620+
label = com.pprint_thing(label)
1621+
_plot(y, i, ax, label, style, **kwds)
16281622

16291623
def _maybe_convert_index(self, data):
16301624
# tsplot converts automatically, but don't want to convert index
@@ -1828,6 +1822,16 @@ class BoxPlot(MPLPlot):
18281822
class HistPlot(MPLPlot):
18291823
pass
18301824

1825+
# kinds supported by both dataframe and series
1826+
_common_kinds = ['line', 'bar', 'barh', 'kde', 'density']
1827+
# kinds supported by dataframe
1828+
_dataframe_kinds = ['scatter', 'hexbin']
1829+
_all_kinds = _common_kinds + _dataframe_kinds
1830+
1831+
_plot_klass = {'line': LinePlot, 'bar': BarPlot, 'barh': BarPlot,
1832+
'kde': KdePlot,
1833+
'scatter': ScatterPlot, 'hexbin': HexBinPlot}
1834+
18311835

18321836
def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
18331837
sharey=False, use_index=True, figsize=None, grid=None,
@@ -1921,21 +1925,14 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
19211925
is a function of one argument that reduces all the values in a bin to
19221926
a single number (e.g. `mean`, `max`, `sum`, `std`).
19231927
"""
1928+
19241929
kind = _get_standard_kind(kind.lower().strip())
1925-
if kind == 'line':
1926-
klass = LinePlot
1927-
elif kind in ('bar', 'barh'):
1928-
klass = BarPlot
1929-
elif kind == 'kde':
1930-
klass = KdePlot
1931-
elif kind == 'scatter':
1932-
klass = ScatterPlot
1933-
elif kind == 'hexbin':
1934-
klass = HexBinPlot
1930+
if kind in _dataframe_kinds or kind in _common_kinds:
1931+
klass = _plot_klass[kind]
19351932
else:
19361933
raise ValueError('Invalid chart type given %s' % kind)
19371934

1938-
if kind == 'scatter':
1935+
if kind in _dataframe_kinds:
19391936
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
19401937
rot=rot,legend=legend, ax=ax, style=style,
19411938
fontsize=fontsize, use_index=use_index, sharex=sharex,
@@ -1944,16 +1941,6 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
19441941
figsize=figsize, logx=logx, logy=logy,
19451942
sort_columns=sort_columns, secondary_y=secondary_y,
19461943
**kwds)
1947-
elif kind == 'hexbin':
1948-
C = kwds.pop('C', None) # remove from kwargs so we can set default
1949-
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
1950-
rot=rot,legend=legend, ax=ax, style=style,
1951-
fontsize=fontsize, use_index=use_index, sharex=sharex,
1952-
sharey=sharey, xticks=xticks, yticks=yticks,
1953-
xlim=xlim, ylim=ylim, title=title, grid=grid,
1954-
figsize=figsize, logx=logx, logy=logy,
1955-
sort_columns=sort_columns, secondary_y=secondary_y,
1956-
C=C, **kwds)
19571944
else:
19581945
if x is not None:
19591946
if com.is_integer(x) and not frame.columns.holds_integer():
@@ -2051,14 +2038,9 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
20512038
See matplotlib documentation online for more on this subject
20522039
"""
20532040

2054-
from pandas import DataFrame
20552041
kind = _get_standard_kind(kind.lower().strip())
2056-
if kind == 'line':
2057-
klass = LinePlot
2058-
elif kind in ('bar', 'barh'):
2059-
klass = BarPlot
2060-
elif kind == 'kde':
2061-
klass = KdePlot
2042+
if kind in _common_kinds:
2043+
klass = _plot_klass[kind]
20622044
else:
20632045
raise ValueError('Invalid chart type given %s' % kind)
20642046

0 commit comments

Comments
 (0)