Skip to content

Commit 6bf2ef9

Browse files
committed
Merge pull request #1025 from changhiskhan/histogram-options
ENH: label sizes and rotations for histogram, close #1012
2 parents 187de26 + 5811602 commit 6bf2ef9

File tree

2 files changed

+98
-16
lines changed

2 files changed

+98
-16
lines changed

pandas/tests/test_graphics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,30 @@ def test_hist(self):
142142
df = DataFrame(np.random.randn(100, 6))
143143
_check_plot_works(df.hist)
144144

145+
#make sure kwargs are handled
146+
ser = df[0]
147+
xf, yf = 20, 20
148+
xrot, yrot = 30, 30
149+
ax = ser.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
150+
ytick = ax.get_yticklabels()[0]
151+
xtick = ax.get_xticklabels()[0]
152+
self.assertAlmostEqual(ytick.get_fontsize(), yf)
153+
self.assertAlmostEqual(ytick.get_rotation(), yrot)
154+
self.assertAlmostEqual(xtick.get_fontsize(), xf)
155+
self.assertAlmostEqual(xtick.get_rotation(), xrot)
156+
157+
xf, yf = 20, 20
158+
xrot, yrot = 30, 30
159+
axes = df.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30)
160+
for i, ax in enumerate(axes.ravel()):
161+
if i < len(df.columns):
162+
ytick = ax.get_yticklabels()[0]
163+
xtick = ax.get_xticklabels()[0]
164+
self.assertAlmostEqual(ytick.get_fontsize(), yf)
165+
self.assertAlmostEqual(ytick.get_rotation(), yrot)
166+
self.assertAlmostEqual(xtick.get_fontsize(), xf)
167+
self.assertAlmostEqual(xtick.get_rotation(), xrot)
168+
145169
@slow
146170
def test_scatter(self):
147171
df = DataFrame(np.random.randn(100, 4))

pandas/tools/plotting.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,15 @@ def _args_adjust(self):
205205
def _setup_subplots(self):
206206
if self.subplots:
207207
nrows, ncols = self._get_layout()
208-
fig, axes = _subplots(nrows=nrows, ncols=ncols,
209-
sharex=self.sharex, sharey=self.sharey,
210-
figsize=self.figsize)
208+
if self.ax is None:
209+
fig, axes = _subplots(nrows=nrows, ncols=ncols,
210+
sharex=self.sharex, sharey=self.sharey,
211+
figsize=self.figsize)
212+
else:
213+
fig, axes = _subplots(nrows=nrows, ncols=ncols,
214+
sharex=self.sharex, sharey=self.sharey,
215+
figsize=self.figsize, ax=self.ax)
216+
211217
else:
212218
if self.ax is None:
213219
fig = self.plt.figure(figsize=self.figsize)
@@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
509515
-------
510516
ax_or_axes : matplotlib.AxesSubplot or list of them
511517
"""
518+
kind = kind.lower().strip()
512519
if kind == 'line':
513520
klass = LinePlot
514521
elif kind in ('bar', 'barh'):
515522
klass = BarPlot
523+
else:
524+
raise ValueError('Invalid chart type given %s' % kind)
516525

517526
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
518527
legend=legend, ax=ax, fontsize=fontsize,
@@ -691,49 +700,84 @@ def plot_group(group, ax):
691700
ax.scatter(xvals, yvals)
692701

693702
if by is not None:
694-
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize)
703+
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
695704
else:
696-
fig = plt.figure()
697-
ax = fig.add_subplot(111)
705+
if ax is None:
706+
fig = plt.figure()
707+
ax = fig.add_subplot(111)
708+
else:
709+
fig = ax.get_figure()
698710
plot_group(data, ax)
699711
ax.set_ylabel(str(y))
700712
ax.set_xlabel(str(x))
701713

702714
return fig
703715

704716

705-
def hist_frame(data, grid=True, **kwds):
717+
def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
718+
ylabelsize=None, yrot=None, ax=None, **kwds):
706719
"""
707720
Draw Histogram the DataFrame's series using matplotlib / pylab.
708721
709722
Parameters
710723
----------
724+
grid : boolean, default True
725+
Whether to show axis grid lines
726+
xlabelsize : int, default None
727+
If specified changes the x-axis label size
728+
xrot : float, default None
729+
rotation of x axis labels
730+
ylabelsize : int, default None
731+
If specified changes the y-axis label size
732+
yrot : float, default None
733+
rotation of y axis labels
734+
ax : matplotlib axes object, default None
711735
kwds : other plotting keyword arguments
712736
To be passed to hist function
713737
"""
738+
import matplotlib.pyplot as plt
714739
n = len(data.columns)
715740
k = 1
716741
while k ** 2 < n:
717742
k += 1
718-
_, axes = _subplots(nrows=k, ncols=k)
743+
_, axes = _subplots(nrows=k, ncols=k, ax=ax)
719744

720745
for i, col in enumerate(com._try_sort(data.columns)):
721746
ax = axes[i / k][i % k]
722747
ax.hist(data[col].dropna().values, **kwds)
723748
ax.set_title(col)
724749
ax.grid(grid)
725750

726-
return axes
751+
if xlabelsize is not None:
752+
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
753+
if xrot is not None:
754+
plt.setp(ax.get_xticklabels(), rotation=xrot)
755+
if ylabelsize is not None:
756+
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
757+
if yrot is not None:
758+
plt.setp(ax.get_yticklabels(), rotation=yrot)
727759

760+
return axes
728761

729-
def hist_series(self, ax=None, grid=True, **kwds):
762+
def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None,
763+
ylabelsize=None, yrot=None, **kwds):
730764
"""
731765
Draw histogram of the input series using matplotlib
732766
733767
Parameters
734768
----------
735769
ax : matplotlib axis object
736770
If not passed, uses gca()
771+
grid : boolean, default True
772+
Whether to show axis grid lines
773+
xlabelsize : int, default None
774+
If specified changes the x-axis label size
775+
xrot : float, default None
776+
rotation of x axis labels
777+
ylabelsize : int, default None
778+
If specified changes the y-axis label size
779+
yrot : float, default None
780+
rotation of y axis labels
737781
kwds : keywords
738782
To be passed to the actual plotting function
739783
@@ -752,12 +796,21 @@ def hist_series(self, ax=None, grid=True, **kwds):
752796
ax.hist(values, **kwds)
753797
ax.grid(grid)
754798

799+
if xlabelsize is not None:
800+
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
801+
if xrot is not None:
802+
plt.setp(ax.get_xticklabels(), rotation=xrot)
803+
if ylabelsize is not None:
804+
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
805+
if yrot is not None:
806+
plt.setp(ax.get_yticklabels(), rotation=yrot)
807+
755808
return ax
756809

757810

758811
def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
759812
figsize=None, sharex=True, sharey=True, layout=None,
760-
rot=0):
813+
rot=0, ax=None):
761814
from pandas.core.frame import DataFrame
762815

763816
# allow to specify mpl default with 'default'
@@ -777,7 +830,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
777830
# default size
778831
figsize = (10, 5)
779832
fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize,
780-
sharex=sharex, sharey=sharey)
833+
sharex=sharex, sharey=sharey, ax=ax)
781834

782835
ravel_axes = []
783836
for row in axes:
@@ -794,7 +847,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
794847

795848
def _grouped_plot_by_column(plotf, data, columns=None, by=None,
796849
numeric_only=True, grid=False,
797-
figsize=None):
850+
figsize=None, ax=None):
798851
import matplotlib.pyplot as plt
799852

800853
grouped = data.groupby(by)
@@ -805,7 +858,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
805858
nrows, ncols = _get_layout(ngroups)
806859
fig, axes = _subplots(nrows=nrows, ncols=ncols,
807860
sharex=True, sharey=True,
808-
figsize=figsize)
861+
figsize=figsize, ax=ax)
809862

810863
if isinstance(axes, plt.Axes):
811864
ravel_axes = [axes]
@@ -850,7 +903,7 @@ def _get_layout(nplots):
850903
# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
851904

852905
def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
853-
subplot_kw=None, **fig_kw):
906+
subplot_kw=None, ax=None, **fig_kw):
854907
"""Create a figure with a set of subplots already made.
855908
856909
This utility wrapper makes it convenient to create common layouts of
@@ -890,6 +943,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
890943
Dict with keywords passed to the figure() call. Note that all keywords
891944
not recognized above will be automatically included here.
892945
946+
ax : Matplotlib axis object, default None
947+
893948
Returns:
894949
895950
fig, ax : tuple
@@ -922,7 +977,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
922977
if subplot_kw is None:
923978
subplot_kw = {}
924979

925-
fig = plt.figure(**fig_kw)
980+
if ax is None:
981+
fig = plt.figure(**fig_kw)
982+
else:
983+
fig = ax.get_figure()
926984

927985
# Create empty object array to hold all axes. It's easiest to make it 1-d
928986
# so we can just append subplots upon creation, and then

0 commit comments

Comments
 (0)