Skip to content

Commit 44d3c54

Browse files
committed
fixup! ENH: .squeeze accepts axis parameter
1 parent cc018c9 commit 44d3c54

File tree

2 files changed

+7
-23
lines changed

2 files changed

+7
-23
lines changed

pandas/core/generic.py

+5-21
Original file line numberDiff line numberDiff line change
@@ -537,32 +537,16 @@ def squeeze(self, axis=None):
537537
538538
Parameters
539539
----------
540-
axis : None or int or tuple of ints, optional
541-
Selects a subset of the single-dimensional entries in the
542-
shape. If an axis is selected with shape entry greater than
543-
one, an error is raised.
540+
axis : None or int, optional
541+
The axis to squeeze if 1-sized.
544542
545543
.. versionadded:: 0.20.0
546544
"""
547-
if axis is None:
548-
axis = tuple(range(len(self.axes)))
549-
else:
550-
if not is_list_like(axis):
551-
axis = (axis,)
552-
if not all(is_integer(ax) for ax in axis):
553-
raise TypeError('an integer is required for the axis')
554-
n_axes = len(self.axes)
555-
for ax in axis:
556-
if ax < -n_axes or ax >= n_axes:
557-
raise ValueError("'axis' entry {0} is out of bounds "
558-
"[-{1}, {1})".format(ax, n_axes))
559-
if any(len(self.axes[ax]) != 1 for ax in axis):
560-
raise ValueError('cannot select an axis to squeeze out which '
561-
'has size not equal to one')
562-
545+
axis = (self._AXIS_NAMES if axis is None else
546+
(self._get_axis_number(axis),))
563547
try:
564548
return self.iloc[
565-
tuple([0 if len(a) == 1 and i in axis else slice(None)
549+
tuple([0 if i in axis and len(a) == 1 else slice(None)
566550
for i, a in enumerate(self.axes)])]
567551
except:
568552
return self

pandas/tests/test_generic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1777,10 +1777,10 @@ def test_squeeze(self):
17771777
tm.assert_series_equal(df.squeeze(axis=1), df.iloc[:, 0])
17781778
tm.assert_equal(df.squeeze(), df.iloc[0, 0])
17791779
tm.assertRaises(ValueError, df.squeeze, axis=2)
1780-
tm.assertRaises(TypeError, df.squeeze, axis='x')
1780+
tm.assertRaises(ValueError, df.squeeze, axis='x')
17811781

17821782
df = tm.makeTimeDataFrame(3)
1783-
tm.assertRaises(ValueError, df.squeeze, axis=0)
1783+
tm.assert_frame_equal(df.squeeze(axis=0), df)
17841784

17851785
def test_numpy_squeeze(self):
17861786
s = tm.makeFloatSeries()

0 commit comments

Comments
 (0)