Skip to content

Commit 06fd41c

Browse files
Evan Wrightevanpw
Evan Wright
authored andcommitted
Allow clip, clip_lower, and clip_upper to use array-like thresholds (GH 6966)
1 parent dcc7431 commit 06fd41c

File tree

4 files changed

+83
-15
lines changed

4 files changed

+83
-15
lines changed

pandas/core/generic.py

+31-14
Original file line numberDiff line numberDiff line change
@@ -2809,14 +2809,16 @@ def notnull(self):
28092809
"""
28102810
return notnull(self).__finalize__(self)
28112811

2812-
def clip(self, lower=None, upper=None, out=None):
2812+
def clip(self, lower=None, upper=None, out=None, axis=None):
28132813
"""
28142814
Trim values at input threshold(s)
28152815
28162816
Parameters
28172817
----------
2818-
lower : float, default None
2819-
upper : float, default None
2818+
lower : float or array_like, default None
2819+
upper : float or array_like, default None
2820+
axis : int or string axis name, optional
2821+
Align object with lower and upper along the given axis.
28202822
28212823
Returns
28222824
-------
@@ -2827,19 +2829,26 @@ def clip(self, lower=None, upper=None, out=None):
28272829

28282830
# GH 2747 (arguments were reversed)
28292831
if lower is not None and upper is not None:
2830-
lower, upper = min(lower, upper), max(lower, upper)
2832+
if lib.isscalar(lower) and lib.isscalar(upper):
2833+
lower, upper = min(lower, upper), max(lower, upper)
28312834

28322835
result = self
28332836
if lower is not None:
2834-
result = result.clip_lower(lower)
2837+
result = result.clip_lower(lower, axis)
28352838
if upper is not None:
2836-
result = result.clip_upper(upper)
2839+
result = result.clip_upper(upper, axis)
28372840

28382841
return result
28392842

2840-
def clip_upper(self, threshold):
2843+
def clip_upper(self, threshold, axis=None):
28412844
"""
2842-
Return copy of input with values above given value truncated
2845+
Return copy of input with values above given value(s) truncated
2846+
2847+
Parameters
2848+
----------
2849+
threshold : float or array_like
2850+
axis : int or string axis name, optional
2851+
Align object with threshold along the given axis.
28432852
28442853
See also
28452854
--------
@@ -2849,14 +2858,21 @@ def clip_upper(self, threshold):
28492858
-------
28502859
clipped : same type as input
28512860
"""
2852-
if isnull(threshold):
2861+
if np.any(isnull(threshold)):
28532862
raise ValueError("Cannot use an NA value as a clip threshold")
28542863

2855-
return self.where((self <= threshold) | isnull(self), threshold)
2864+
subset = self.le(threshold, axis=axis) | isnull(self)
2865+
return self.where(subset, threshold, axis=axis)
28562866

2857-
def clip_lower(self, threshold):
2867+
def clip_lower(self, threshold, axis=None):
28582868
"""
2859-
Return copy of the input with values below given value truncated
2869+
Return copy of the input with values below given value(s) truncated
2870+
2871+
Parameters
2872+
----------
2873+
threshold : float or array_like
2874+
axis : int or string axis name, optional
2875+
Align object with threshold along the given axis.
28602876
28612877
See also
28622878
--------
@@ -2866,10 +2882,11 @@ def clip_lower(self, threshold):
28662882
-------
28672883
clipped : same type as input
28682884
"""
2869-
if isnull(threshold):
2885+
if np.any(isnull(threshold)):
28702886
raise ValueError("Cannot use an NA value as a clip threshold")
28712887

2872-
return self.where((self >= threshold) | isnull(self), threshold)
2888+
subset = self.ge(threshold, axis=axis) | isnull(self)
2889+
return self.where(subset, threshold, axis=axis)
28732890

28742891
def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
28752892
group_keys=True, squeeze=False):

pandas/core/ops.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,11 @@ def na_op(x, y):
571571

572572
return result
573573

574-
def wrapper(self, other):
574+
def wrapper(self, other, axis=None):
575+
# Validate the axis parameter
576+
if axis is not None:
577+
self._get_axis_number(axis)
578+
575579
if isinstance(other, pd.Series):
576580
name = _maybe_match_name(self, other)
577581
if len(self) != len(other):

pandas/tests/test_frame.py

+33
Original file line numberDiff line numberDiff line change
@@ -11253,6 +11253,39 @@ def test_dataframe_clip(self):
1125311253
self.assertTrue((clipped_df.values[ub_mask] == ub).all() == True)
1125411254
self.assertTrue((clipped_df.values[mask] == df.values[mask]).all() == True)
1125511255

11256+
def test_clip_against_series(self):
11257+
# GH #6966
11258+
11259+
df = DataFrame(np.random.randn(1000, 2))
11260+
lb = Series(np.random.randn(1000))
11261+
ub = lb + 1
11262+
11263+
clipped_df = df.clip(lb, ub, axis=0)
11264+
11265+
for i in range(2):
11266+
lb_mask = df.iloc[:, i] <= lb
11267+
ub_mask = df.iloc[:, i] >= ub
11268+
mask = ~lb_mask & ~ub_mask
11269+
11270+
assert_series_equal(clipped_df.loc[lb_mask, i], lb[lb_mask])
11271+
assert_series_equal(clipped_df.loc[ub_mask, i], ub[ub_mask])
11272+
assert_series_equal(clipped_df.loc[mask, i], df.loc[mask, i])
11273+
11274+
def test_clip_against_frame(self):
11275+
df = DataFrame(np.random.randn(1000, 2))
11276+
lb = DataFrame(np.random.randn(1000, 2))
11277+
ub = lb + 1
11278+
11279+
clipped_df = df.clip(lb, ub)
11280+
11281+
lb_mask = df <= lb
11282+
ub_mask = df >= ub
11283+
mask = ~lb_mask & ~ub_mask
11284+
11285+
assert_frame_equal(clipped_df[lb_mask], lb[lb_mask])
11286+
assert_frame_equal(clipped_df[ub_mask], ub[ub_mask])
11287+
assert_frame_equal(clipped_df[mask], df[mask])
11288+
1125611289
def test_get_X_columns(self):
1125711290
# numeric and object columns
1125811291

pandas/tests/test_series.py

+14
Original file line numberDiff line numberDiff line change
@@ -4884,6 +4884,20 @@ def test_clip_types_and_nulls(self):
48844884
self.assertEqual(list(isnull(s)), list(isnull(l)))
48854885
self.assertEqual(list(isnull(s)), list(isnull(u)))
48864886

4887+
def test_clip_against_series(self):
4888+
# GH #6966
4889+
4890+
s = Series([1.0, 1.0, 4.0])
4891+
threshold = Series([1.0, 2.0, 3.0])
4892+
4893+
assert_series_equal(s.clip_lower(threshold), Series([1.0, 2.0, 4.0]))
4894+
assert_series_equal(s.clip_upper(threshold), Series([1.0, 1.0, 3.0]))
4895+
4896+
lower = Series([1.0, 2.0, 3.0])
4897+
upper = Series([1.5, 2.5, 3.5])
4898+
assert_series_equal(s.clip(lower, upper), Series([1.0, 2.0, 3.5]))
4899+
assert_series_equal(s.clip(1.5, upper), Series([1.5, 1.5, 3.5]))
4900+
48874901
def test_valid(self):
48884902
ts = self.ts.copy()
48894903
ts[::2] = np.NaN

0 commit comments

Comments
 (0)