From d3bcab758bd696e1c081458cbfc48e764a4f7bd0 Mon Sep 17 00:00:00 2001 From: Chang She Date: Wed, 11 Apr 2012 11:49:39 -0400 Subject: [PATCH 1/3] ENH: label sizes and rotations for histogram TST: test cases for both Series and DataFrame histogram --- pandas/tests/test_graphics.py | 24 ++++++++++++++++++ pandas/tools/plotting.py | 46 ++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index a0983ecf8578d..4fe40a6ebe4ad 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -142,6 +142,30 @@ def test_hist(self): df = DataFrame(np.random.randn(100, 6)) _check_plot_works(df.hist) + #make sure kwargs are handled + ser = df[0] + xf, yf = 20, 20 + xrot, yrot = 30, 30 + ax = ser.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30) + ytick = ax.get_yticklabels()[0] + xtick = ax.get_xticklabels()[0] + self.assertAlmostEqual(ytick.get_fontsize(), yf) + self.assertAlmostEqual(ytick.get_rotation(), yrot) + self.assertAlmostEqual(xtick.get_fontsize(), xf) + self.assertAlmostEqual(xtick.get_rotation(), xrot) + + xf, yf = 20, 20 + xrot, yrot = 30, 30 + axes = df.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30) + for i, ax in enumerate(axes.ravel()): + if i < len(df.columns): + ytick = ax.get_yticklabels()[0] + xtick = ax.get_xticklabels()[0] + self.assertAlmostEqual(ytick.get_fontsize(), yf) + self.assertAlmostEqual(ytick.get_rotation(), yrot) + self.assertAlmostEqual(xtick.get_fontsize(), xf) + self.assertAlmostEqual(xtick.get_rotation(), xrot) + @slow def test_scatter(self): df = DataFrame(np.random.randn(100, 4)) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index f80b69f2104ad..25808693d4596 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -702,15 +702,27 @@ def plot_group(group, ax): return fig -def hist_frame(data, grid=True, **kwds): +def hist_frame(data, grid=True, xlabelsize=None, xrot=None, + ylabelsize=None, yrot=None, **kwds): """ Draw Histogram the DataFrame's series using matplotlib / pylab. Parameters ---------- + grid : boolean, default True + Whether to show axis grid lines + xlabelsize : int, default None + If specified changes the x-axis label size + xrot : float, default None + rotation of x axis labels + ylabelsize : int, default None + If specified changes the y-axis label size + yrot : float, default None + rotation of y axis labels kwds : other plotting keyword arguments To be passed to hist function """ + import matplotlib.pyplot as plt n = len(data.columns) k = 1 while k ** 2 < n: @@ -723,10 +735,19 @@ def hist_frame(data, grid=True, **kwds): ax.set_title(col) ax.grid(grid) - return axes + if xlabelsize is not None: + plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(ax.get_xticklabels(), rotation=xrot) + if ylabelsize is not None: + plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(ax.get_yticklabels(), rotation=yrot) + return axes -def hist_series(self, ax=None, grid=True, **kwds): +def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None, + ylabelsize=None, yrot=None, **kwds): """ Draw histogram of the input series using matplotlib @@ -734,6 +755,16 @@ def hist_series(self, ax=None, grid=True, **kwds): ---------- ax : matplotlib axis object If not passed, uses gca() + grid : boolean, default True + Whether to show axis grid lines + xlabelsize : int, default None + If specified changes the x-axis label size + xrot : float, default None + rotation of x axis labels + ylabelsize : int, default None + If specified changes the y-axis label size + yrot : float, default None + rotation of y axis labels kwds : keywords To be passed to the actual plotting function @@ -752,6 +783,15 @@ def hist_series(self, ax=None, grid=True, **kwds): ax.hist(values, **kwds) ax.grid(grid) + if xlabelsize is not None: + plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(ax.get_xticklabels(), rotation=xrot) + if ylabelsize is not None: + plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(ax.get_yticklabels(), rotation=yrot) + return ax From 72983135dcd31b2c0786d4df27c766aa34bb9088 Mon Sep 17 00:00:00 2001 From: Chang She Date: Sun, 8 Apr 2012 07:52:02 -0400 Subject: [PATCH 2/3] add ax kwd to several functions and push ax into subplots so new subplot axes is generated on the ax's figure --- pandas/tools/plotting.py | 44 +++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 25808693d4596..b5a4b7f3e6d18 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -205,9 +205,15 @@ def _args_adjust(self): def _setup_subplots(self): if self.subplots: nrows, ncols = self._get_layout() - fig, axes = _subplots(nrows=nrows, ncols=ncols, - sharex=self.sharex, sharey=self.sharey, - figsize=self.figsize) + if self.ax is None: + fig, axes = _subplots(nrows=nrows, ncols=ncols, + sharex=self.sharex, sharey=self.sharey, + figsize=self.figsize) + else: + fig, axes = _subplots(nrows=nrows, ncols=ncols, + sharex=self.sharex, sharey=self.sharey, + figsize=self.figsize, ax=self.ax) + else: if self.ax is None: fig = self.plt.figure(figsize=self.figsize) @@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False, ------- ax_or_axes : matplotlib.AxesSubplot or list of them """ + kind = kind.lower().strip() if kind == 'line': klass = LinePlot elif kind in ('bar', 'barh'): klass = BarPlot + else: + raise ValueError('Invalid chart type given %s' % kind) plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, legend=legend, ax=ax, fontsize=fontsize, @@ -691,10 +700,11 @@ def plot_group(group, ax): ax.scatter(xvals, yvals) if by is not None: - fig = _grouped_plot(plot_group, data, by=by, figsize=figsize) + fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax) else: - fig = plt.figure() - ax = fig.add_subplot(111) + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111) plot_group(data, ax) ax.set_ylabel(str(y)) ax.set_xlabel(str(x)) @@ -703,7 +713,7 @@ def plot_group(group, ax): def hist_frame(data, grid=True, xlabelsize=None, xrot=None, - ylabelsize=None, yrot=None, **kwds): + ylabelsize=None, yrot=None, ax=None, **kwds): """ Draw Histogram the DataFrame's series using matplotlib / pylab. @@ -719,6 +729,7 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None, If specified changes the y-axis label size yrot : float, default None rotation of y axis labels + ax : matplotlib axes object, default None kwds : other plotting keyword arguments To be passed to hist function """ @@ -727,7 +738,7 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None, k = 1 while k ** 2 < n: k += 1 - _, axes = _subplots(nrows=k, ncols=k) + _, axes = _subplots(nrows=k, ncols=k, ax=ax) for i, col in enumerate(com._try_sort(data.columns)): ax = axes[i / k][i % k] @@ -797,7 +808,7 @@ def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None, def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, figsize=None, sharex=True, sharey=True, layout=None, - rot=0): + rot=0, ax=None): from pandas.core.frame import DataFrame # allow to specify mpl default with 'default' @@ -817,7 +828,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, # default size figsize = (10, 5) fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize, - sharex=sharex, sharey=sharey) + sharex=sharex, sharey=sharey, ax=ax) ravel_axes = [] for row in axes: @@ -834,7 +845,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, def _grouped_plot_by_column(plotf, data, columns=None, by=None, numeric_only=True, grid=False, - figsize=None): + figsize=None, ax=None): import matplotlib.pyplot as plt grouped = data.groupby(by) @@ -845,7 +856,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None, nrows, ncols = _get_layout(ngroups) fig, axes = _subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, - figsize=figsize) + figsize=figsize, ax=ax) if isinstance(axes, plt.Axes): ravel_axes = [axes] @@ -890,7 +901,7 @@ def _get_layout(nplots): # copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0 def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, - subplot_kw=None, **fig_kw): + subplot_kw=None, ax=None, **fig_kw): """Create a figure with a set of subplots already made. This utility wrapper makes it convenient to create common layouts of @@ -930,6 +941,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, Dict with keywords passed to the figure() call. Note that all keywords not recognized above will be automatically included here. + ax : Matplotlib axis object, default None + Returns: fig, ax : tuple @@ -962,7 +975,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, if subplot_kw is None: subplot_kw = {} - fig = plt.figure(**fig_kw) + if ax is None: + fig = plt.figure(**fig_kw) + else: + fig = ax.get_figure() # Create empty object array to hold all axes. It's easiest to make it 1-d # so we can just append subplots upon creation, and then From 58116025242088dca566f6468b645d73d4f7b22f Mon Sep 17 00:00:00 2001 From: Chang She Date: Sun, 8 Apr 2012 09:20:53 -0400 Subject: [PATCH 3/3] BUG: return ax.get_figure() in scatter_plot if ax argument is not None --- pandas/tools/plotting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index b5a4b7f3e6d18..ea30787d0f172 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -705,6 +705,8 @@ def plot_group(group, ax): if ax is None: fig = plt.figure() ax = fig.add_subplot(111) + else: + fig = ax.get_figure() plot_group(data, ax) ax.set_ylabel(str(y)) ax.set_xlabel(str(x))