Skip to content

ENH: label sizes and rotations for histogram #1025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 11, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,30 @@ def test_hist(self):
df = DataFrame(np.random.randn(100, 6))
_check_plot_works(df.hist)

#make sure kwargs are handled
ser = df[0]
xf, yf = 20, 20
xrot, yrot = 30, 30
ax = ser.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
ytick = ax.get_yticklabels()[0]
xtick = ax.get_xticklabels()[0]
self.assertAlmostEqual(ytick.get_fontsize(), yf)
self.assertAlmostEqual(ytick.get_rotation(), yrot)
self.assertAlmostEqual(xtick.get_fontsize(), xf)
self.assertAlmostEqual(xtick.get_rotation(), xrot)

xf, yf = 20, 20
xrot, yrot = 30, 30
axes = df.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
for i, ax in enumerate(axes.ravel()):
if i < len(df.columns):
ytick = ax.get_yticklabels()[0]
xtick = ax.get_xticklabels()[0]
self.assertAlmostEqual(ytick.get_fontsize(), yf)
self.assertAlmostEqual(ytick.get_rotation(), yrot)
self.assertAlmostEqual(xtick.get_fontsize(), xf)
self.assertAlmostEqual(xtick.get_rotation(), xrot)

@slow
def test_scatter(self):
df = DataFrame(np.random.randn(100, 4))
Expand Down
90 changes: 74 additions & 16 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,15 @@ def _args_adjust(self):
def _setup_subplots(self):
if self.subplots:
nrows, ncols = self._get_layout()
fig, axes = _subplots(nrows=nrows, ncols=ncols,
sharex=self.sharex, sharey=self.sharey,
figsize=self.figsize)
if self.ax is None:
fig, axes = _subplots(nrows=nrows, ncols=ncols,
sharex=self.sharex, sharey=self.sharey,
figsize=self.figsize)
else:
fig, axes = _subplots(nrows=nrows, ncols=ncols,
sharex=self.sharex, sharey=self.sharey,
figsize=self.figsize, ax=self.ax)

else:
if self.ax is None:
fig = self.plt.figure(figsize=self.figsize)
Expand Down Expand Up @@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
-------
ax_or_axes : matplotlib.AxesSubplot or list of them
"""
kind = kind.lower().strip()
if kind == 'line':
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
else:
raise ValueError('Invalid chart type given %s' % kind)

plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, fontsize=fontsize,
Expand Down Expand Up @@ -691,49 +700,84 @@ def plot_group(group, ax):
ax.scatter(xvals, yvals)

if by is not None:
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize)
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
else:
fig = plt.figure()
ax = fig.add_subplot(111)
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)
else:
fig = ax.get_figure()
plot_group(data, ax)
ax.set_ylabel(str(y))
ax.set_xlabel(str(x))

return fig


def hist_frame(data, grid=True, **kwds):
def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
ylabelsize=None, yrot=None, ax=None, **kwds):
"""
Draw Histogram the DataFrame's series using matplotlib / pylab.

Parameters
----------
grid : boolean, default True
Whether to show axis grid lines
xlabelsize : int, default None
If specified changes the x-axis label size
xrot : float, default None
rotation of x axis labels
ylabelsize : int, default None
If specified changes the y-axis label size
yrot : float, default None
rotation of y axis labels
ax : matplotlib axes object, default None
kwds : other plotting keyword arguments
To be passed to hist function
"""
import matplotlib.pyplot as plt
n = len(data.columns)
k = 1
while k ** 2 < n:
k += 1
_, axes = _subplots(nrows=k, ncols=k)
_, axes = _subplots(nrows=k, ncols=k, ax=ax)

for i, col in enumerate(com._try_sort(data.columns)):
ax = axes[i / k][i % k]
ax.hist(data[col].dropna().values, **kwds)
ax.set_title(col)
ax.grid(grid)

return axes
if xlabelsize is not None:
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(ax.get_xticklabels(), rotation=xrot)
if ylabelsize is not None:
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(ax.get_yticklabels(), rotation=yrot)

return axes

def hist_series(self, ax=None, grid=True, **kwds):
def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None,
ylabelsize=None, yrot=None, **kwds):
"""
Draw histogram of the input series using matplotlib

Parameters
----------
ax : matplotlib axis object
If not passed, uses gca()
grid : boolean, default True
Whether to show axis grid lines
xlabelsize : int, default None
If specified changes the x-axis label size
xrot : float, default None
rotation of x axis labels
ylabelsize : int, default None
If specified changes the y-axis label size
yrot : float, default None
rotation of y axis labels
kwds : keywords
To be passed to the actual plotting function

Expand All @@ -752,12 +796,21 @@ def hist_series(self, ax=None, grid=True, **kwds):
ax.hist(values, **kwds)
ax.grid(grid)

if xlabelsize is not None:
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(ax.get_xticklabels(), rotation=xrot)
if ylabelsize is not None:
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(ax.get_yticklabels(), rotation=yrot)

return ax


def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0):
rot=0, ax=None):
from pandas.core.frame import DataFrame

# allow to specify mpl default with 'default'
Expand All @@ -777,7 +830,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
# default size
figsize = (10, 5)
fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize,
sharex=sharex, sharey=sharey)
sharex=sharex, sharey=sharey, ax=ax)

ravel_axes = []
for row in axes:
Expand All @@ -794,7 +847,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,

def _grouped_plot_by_column(plotf, data, columns=None, by=None,
numeric_only=True, grid=False,
figsize=None):
figsize=None, ax=None):
import matplotlib.pyplot as plt

grouped = data.groupby(by)
Expand All @@ -805,7 +858,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
nrows, ncols = _get_layout(ngroups)
fig, axes = _subplots(nrows=nrows, ncols=ncols,
sharex=True, sharey=True,
figsize=figsize)
figsize=figsize, ax=ax)

if isinstance(axes, plt.Axes):
ravel_axes = [axes]
Expand Down Expand Up @@ -850,7 +903,7 @@ def _get_layout(nplots):
# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0

def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
subplot_kw=None, **fig_kw):
subplot_kw=None, ax=None, **fig_kw):
"""Create a figure with a set of subplots already made.

This utility wrapper makes it convenient to create common layouts of
Expand Down Expand Up @@ -890,6 +943,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
Dict with keywords passed to the figure() call. Note that all keywords
not recognized above will be automatically included here.

ax : Matplotlib axis object, default None

Returns:

fig, ax : tuple
Expand Down Expand Up @@ -922,7 +977,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
if subplot_kw is None:
subplot_kw = {}

fig = plt.figure(**fig_kw)
if ax is None:
fig = plt.figure(**fig_kw)
else:
fig = ax.get_figure()

# Create empty object array to hold all axes. It's easiest to make it 1-d
# so we can just append subplots upon creation, and then
Expand Down