Skip to content

Commit cc018c9

Browse files
committed
ENH: .squeeze accepts axis parameter
1 parent 542c916 commit cc018c9

File tree

4 files changed

+44
-16
lines changed

4 files changed

+44
-16
lines changed

doc/source/whatsnew/v0.20.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ Other enhancements
150150
- ``Series/DataFrame.resample.asfreq`` have gained a ``fill_value`` parameter, to fill missing values during resampling (:issue:`3715`).
151151
- ``pandas.tools.hashing`` has gained a ``hash_tuples`` routine, and ``hash_pandas_object`` has gained the ability to hash a ``MultiIndex`` (:issue:`15224`)
152152

153+
- ``Series/DataFrame.squeeze()`` have gained support for ``axis`` parameter. (:issue:15339``)
154+
153155
.. _ISO 8601 duration: https://en.wikipedia.org/wiki/ISO_8601#Durations
154156

155157
.. _whatsnew_0200.api_breaking:

pandas/compat/numpy/function.py

-7
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,6 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
214214
validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS,
215215
method='kwargs')
216216

217-
# Currently, numpy (v1.11) has backwards compatibility checks
218-
# in place so that this 'kwargs' parameter is technically
219-
# unnecessary, but in the long-run, this will be needed.
220-
SQUEEZE_DEFAULTS = dict(axis=None)
221-
validate_squeeze = CompatValidator(SQUEEZE_DEFAULTS, fname='squeeze',
222-
method='kwargs')
223-
224217
TAKE_DEFAULTS = OrderedDict()
225218
TAKE_DEFAULTS['out'] = None
226219
TAKE_DEFAULTS['mode'] = 'raise'

pandas/core/generic.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -532,13 +532,38 @@ def pop(self, item):
532532

533533
return result
534534

535-
def squeeze(self, **kwargs):
536-
"""Squeeze length 1 dimensions."""
537-
nv.validate_squeeze(tuple(), kwargs)
535+
def squeeze(self, axis=None):
536+
"""Squeeze length 1 dimensions.
537+
538+
Parameters
539+
----------
540+
axis : None or int or tuple of ints, optional
541+
Selects a subset of the single-dimensional entries in the
542+
shape. If an axis is selected with shape entry greater than
543+
one, an error is raised.
544+
545+
.. versionadded:: 0.20.0
546+
"""
547+
if axis is None:
548+
axis = tuple(range(len(self.axes)))
549+
else:
550+
if not is_list_like(axis):
551+
axis = (axis,)
552+
if not all(is_integer(ax) for ax in axis):
553+
raise TypeError('an integer is required for the axis')
554+
n_axes = len(self.axes)
555+
for ax in axis:
556+
if ax < -n_axes or ax >= n_axes:
557+
raise ValueError("'axis' entry {0} is out of bounds "
558+
"[-{1}, {1})".format(ax, n_axes))
559+
if any(len(self.axes[ax]) != 1 for ax in axis):
560+
raise ValueError('cannot select an axis to squeeze out which '
561+
'has size not equal to one')
538562

539563
try:
540-
return self.iloc[tuple([0 if len(a) == 1 else slice(None)
541-
for a in self.axes])]
564+
return self.iloc[
565+
tuple([0 if len(a) == 1 and i in axis else slice(None)
566+
for i, a in enumerate(self.axes)])]
542567
except:
543568
return self
544569

pandas/tests/test_generic.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1770,17 +1770,25 @@ def test_squeeze(self):
17701770
[tm.assert_series_equal(empty_series, higher_dim.squeeze())
17711771
for higher_dim in [empty_series, empty_frame, empty_panel]]
17721772

1773+
# axis argument
1774+
df = tm.makeTimeDataFrame(nper=1).iloc[:, :1]
1775+
tm.assert_equal(df.shape, (1, 1))
1776+
tm.assert_series_equal(df.squeeze(axis=0), df.iloc[0])
1777+
tm.assert_series_equal(df.squeeze(axis=1), df.iloc[:, 0])
1778+
tm.assert_equal(df.squeeze(), df.iloc[0, 0])
1779+
tm.assertRaises(ValueError, df.squeeze, axis=2)
1780+
tm.assertRaises(TypeError, df.squeeze, axis='x')
1781+
1782+
df = tm.makeTimeDataFrame(3)
1783+
tm.assertRaises(ValueError, df.squeeze, axis=0)
1784+
17731785
def test_numpy_squeeze(self):
17741786
s = tm.makeFloatSeries()
17751787
tm.assert_series_equal(np.squeeze(s), s)
17761788

17771789
df = tm.makeTimeDataFrame().reindex(columns=['A'])
17781790
tm.assert_series_equal(np.squeeze(df), df['A'])
17791791

1780-
msg = "the 'axis' parameter is not supported"
1781-
tm.assertRaisesRegexp(ValueError, msg,
1782-
np.squeeze, s, axis=0)
1783-
17841792
def test_transpose(self):
17851793
msg = (r"transpose\(\) got multiple values for "
17861794
r"keyword argument 'axes'")

0 commit comments

Comments
 (0)