Skip to content

ENH: .squeeze accepts axis parameter #15335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.20.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ Other enhancements
- ``Series/DataFrame.resample.asfreq`` have gained a ``fill_value`` parameter, to fill missing values during resampling (:issue:`3715`).
- ``pandas.tools.hashing`` has gained a ``hash_tuples`` routine, and ``hash_pandas_object`` has gained the ability to hash a ``MultiIndex`` (:issue:`15224`)

- ``Series/DataFrame.squeeze()`` have gained support for an ``axis`` parameter (:issue:`15339`).

.. _ISO 8601 duration: https://en.wikipedia.org/wiki/ISO_8601#Durations

.. _whatsnew_0200.api_breaking:
Expand Down
7 changes: 0 additions & 7 deletions pandas/compat/numpy/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,6 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS,
method='kwargs')

# Currently, numpy (v1.11) has backwards compatibility checks
# in place so that this 'kwargs' parameter is technically
# unnecessary, but in the long-run, this will be needed.
SQUEEZE_DEFAULTS = dict(axis=None)
validate_squeeze = CompatValidator(SQUEEZE_DEFAULTS, fname='squeeze',
method='kwargs')

TAKE_DEFAULTS = OrderedDict()
TAKE_DEFAULTS['out'] = None
TAKE_DEFAULTS['mode'] = 'raise'
Expand Down
35 changes: 30 additions & 5 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,13 +532,38 @@ def pop(self, item):

return result

def squeeze(self, **kwargs):
"""Squeeze length 1 dimensions."""
nv.validate_squeeze(tuple(), kwargs)
def squeeze(self, axis=None):
"""Squeeze length 1 dimensions.

Parameters
----------
axis : None or int or tuple of ints, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you don't need virtually any of this validation; simply use

axis = self._get_axis_number(axis). This can only accept a single axis (or None by default).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like the pushed change?

Selects a subset of the single-dimensional entries in the
shape. If an axis is selected with shape entry greater than
one, an error is raised.

.. versionadded:: 0.20.0
"""
if axis is None:
axis = tuple(range(len(self.axes)))
else:
if not is_list_like(axis):
axis = (axis,)
if not all(is_integer(ax) for ax in axis):
raise TypeError('an integer is required for the axis')
n_axes = len(self.axes)
for ax in axis:
if ax < -n_axes or ax >= n_axes:
raise ValueError("'axis' entry {0} is out of bounds "
"[-{1}, {1})".format(ax, n_axes))
if any(len(self.axes[ax]) != 1 for ax in axis):
raise ValueError('cannot select an axis to squeeze out which '
'has size not equal to one')

try:
return self.iloc[tuple([0 if len(a) == 1 else slice(None)
for a in self.axes])]
return self.iloc[
tuple([0 if len(a) == 1 and i in axis else slice(None)
for i, a in enumerate(self.axes)])]
except:
return self

Expand Down
16 changes: 12 additions & 4 deletions pandas/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,17 +1770,25 @@ def test_squeeze(self):
[tm.assert_series_equal(empty_series, higher_dim.squeeze())
for higher_dim in [empty_series, empty_frame, empty_panel]]

# axis argument
df = tm.makeTimeDataFrame(nper=1).iloc[:, :1]
tm.assert_equal(df.shape, (1, 1))
tm.assert_series_equal(df.squeeze(axis=0), df.iloc[0])
tm.assert_series_equal(df.squeeze(axis=1), df.iloc[:, 0])
tm.assert_equal(df.squeeze(), df.iloc[0, 0])
tm.assertRaises(ValueError, df.squeeze, axis=2)
tm.assertRaises(TypeError, df.squeeze, axis='x')

df = tm.makeTimeDataFrame(3)
tm.assertRaises(ValueError, df.squeeze, axis=0)

def test_numpy_squeeze(self):
s = tm.makeFloatSeries()
tm.assert_series_equal(np.squeeze(s), s)

df = tm.makeTimeDataFrame().reindex(columns=['A'])
tm.assert_series_equal(np.squeeze(df), df['A'])

msg = "the 'axis' parameter is not supported"
tm.assertRaisesRegexp(ValueError, msg,
np.squeeze, s, axis=0)

def test_transpose(self):
msg = (r"transpose\(\) got multiple values for "
r"keyword argument 'axes'")
Expand Down