Commit 7bbd031

sinhrks authored and jreback committed
ENH: Allow where/mask/Indexers to accept callable
closes #12533
closes #11485

Author: sinhrks <[email protected]>

Closes #12539 from sinhrks/where and squashes the following commits:

6b5d618 [sinhrks] ENH: Allow .where to accept callable as condition
1 parent a615dbe commit 7bbd031

13 files changed

+588
-23
lines changed

doc/source/indexing.rst

+79-14
@@ -79,6 +79,10 @@ of multi-axis indexing.
7979
- A slice object with labels ``'a':'f'``, (note that contrary to usual python
8080
slices, **both** the start and the stop are included!)
8181
- A boolean array
82+
- A ``callable`` function that takes one argument (the calling Series, DataFrame or Panel) and
83+
that returns valid output for indexing (one of the above)
84+
85+
.. versionadded:: 0.18.1
8286

8387
See more at :ref:`Selection by Label <indexing.label>`
8488

@@ -93,6 +97,10 @@ of multi-axis indexing.
9397
- A list or array of integers ``[4, 3, 0]``
9498
- A slice object with ints ``1:7``
9599
- A boolean array
100+
- A ``callable`` function that takes one argument (the calling Series, DataFrame or Panel) and
101+
that returns valid output for indexing (one of the above)
102+
103+
.. versionadded:: 0.18.1
96104

97105
See more at :ref:`Selection by Position <indexing.integer>`
98106

@@ -110,6 +118,8 @@ of multi-axis indexing.
110118
See more at :ref:`Advanced Indexing <advanced>` and :ref:`Advanced
111119
Hierarchical <advanced.advanced_hierarchical>`.
112120

121+
- ``.loc``, ``.iloc``, ``.ix`` and also ``[]`` indexing can accept a ``callable`` as an indexer. See more at :ref:`Selection By Callable <indexing.callable>`.
122+
113123
Getting values from an object with multi-axes selection uses the following
114124
notation (using ``.loc`` as an example, but applies to ``.iloc`` and ``.ix`` as
115125
well). Any of the axes accessors may be the null slice ``:``. Axes left out of
@@ -317,6 +327,7 @@ The ``.loc`` attribute is the primary access method. The following are valid inp
317327
- A list or array of labels ``['a', 'b', 'c']``
318328
- A slice object with labels ``'a':'f'`` (note that contrary to usual python slices, **both** the start and the stop are included!)
319329
- A boolean array
330+
- A ``callable``, see :ref:`Selection By Callable <indexing.callable>`
320331

321332
.. ipython:: python
322333
@@ -340,13 +351,13 @@ With a DataFrame
340351
index=list('abcdef'),
341352
columns=list('ABCD'))
342353
df1
343-
df1.loc[['a','b','d'],:]
354+
df1.loc[['a', 'b', 'd'], :]
344355
345356
Accessing via label slices
346357

347358
.. ipython:: python
348359
349-
df1.loc['d':,'A':'C']
360+
df1.loc['d':, 'A':'C']
350361
351362
For getting a cross section using a label (equiv to ``df.xs('a')``)
352363

@@ -358,15 +369,15 @@ For getting values with a boolean array
358369

359370
.. ipython:: python
360371
361-
df1.loc['a']>0
362-
df1.loc[:,df1.loc['a']>0]
372+
df1.loc['a'] > 0
373+
df1.loc[:, df1.loc['a'] > 0]
363374
364375
For getting a value explicitly (equiv to deprecated ``df.get_value('a','A')``)
365376

366377
.. ipython:: python
367378
368379
# this is also equivalent to ``df1.at['a','A']``
369-
df1.loc['a','A']
380+
df1.loc['a', 'A']
370381
371382
.. _indexing.integer:
372383

@@ -387,6 +398,7 @@ The ``.iloc`` attribute is the primary access method. The following are valid in
387398
- A list or array of integers ``[4, 3, 0]``
388399
- A slice object with ints ``1:7``
389400
- A boolean array
401+
- A ``callable``, see :ref:`Selection By Callable <indexing.callable>`
390402

391403
.. ipython:: python
392404
@@ -416,26 +428,26 @@ Select via integer slicing
416428
.. ipython:: python
417429
418430
df1.iloc[:3]
419-
df1.iloc[1:5,2:4]
431+
df1.iloc[1:5, 2:4]
420432
421433
Select via integer list
422434

423435
.. ipython:: python
424436
425-
df1.iloc[[1,3,5],[1,3]]
437+
df1.iloc[[1, 3, 5], [1, 3]]
426438
427439
.. ipython:: python
428440
429-
df1.iloc[1:3,:]
441+
df1.iloc[1:3, :]
430442
431443
.. ipython:: python
432444
433-
df1.iloc[:,1:3]
445+
df1.iloc[:, 1:3]
434446
435447
.. ipython:: python
436448
437449
# this is also equivalent to ``df1.iat[1,1]``
438-
df1.iloc[1,1]
450+
df1.iloc[1, 1]
439451
440452
For getting a cross section using an integer position (equiv to ``df.xs(1)``)
441453

@@ -471,8 +483,8 @@ returned)
471483
472484
dfl = pd.DataFrame(np.random.randn(5,2), columns=list('AB'))
473485
dfl
474-
dfl.iloc[:,2:3]
475-
dfl.iloc[:,1:3]
486+
dfl.iloc[:, 2:3]
487+
dfl.iloc[:, 1:3]
476488
dfl.iloc[4:6]
477489
478490
A single indexer that is out of bounds will raise an ``IndexError``.
@@ -481,12 +493,52 @@ A list of indexers where any element is out of bounds will raise an
481493

482494
.. code-block:: python
483495
484-
dfl.iloc[[4,5,6]]
496+
dfl.iloc[[4, 5, 6]]
485497
IndexError: positional indexers are out-of-bounds
486498
487-
dfl.iloc[:,4]
499+
dfl.iloc[:, 4]
488500
IndexError: single positional indexer is out-of-bounds
489501
502+
.. _indexing.callable:
503+
504+
Selection By Callable
505+
---------------------
506+
507+
.. versionadded:: 0.18.1
508+
509+
``.loc``, ``.iloc``, ``.ix`` and also ``[]`` indexing can accept a ``callable`` as an indexer.
510+
The ``callable`` must be a function that takes one argument (the calling Series, DataFrame or Panel) and returns valid output for indexing.
511+
512+
.. ipython:: python
513+
514+
df1 = pd.DataFrame(np.random.randn(6, 4),
515+
index=list('abcdef'),
516+
columns=list('ABCD'))
517+
df1
518+
519+
df1.loc[lambda df: df.A > 0, :]
520+
df1.loc[:, lambda df: ['A', 'B']]
521+
522+
df1.iloc[:, lambda df: [0, 1]]
523+
524+
df1[lambda df: df.columns[0]]
525+
526+
527+
You can also use callable indexing on a ``Series``.
528+
529+
.. ipython:: python
530+
531+
df1.A.loc[lambda s: s > 0]
532+
533+
Using these methods / indexers, you can chain data selection operations
534+
without using a temporary variable.
535+
536+
.. ipython:: python
537+
538+
bb = pd.read_csv('data/baseball.csv', index_col='id')
539+
(bb.groupby(['year', 'team']).sum()
540+
.loc[lambda df: df.r > 100])
541+
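The callable forms described above can be tried on a small, deterministic frame. The following is a minimal sketch rather than part of the commit: the frame ``df`` and its values are made up for illustration, and only ``.loc``, ``.iloc`` and ``[]`` are shown because ``.ix`` and ``Panel`` are not available in later pandas releases.

```python
import pandas as pd

# Small deterministic frame (illustrative data, not from the commit).
df = pd.DataFrame({'A': [1, -2, 3], 'B': [4, 5, -6]}, index=list('abc'))

# Callable returning a boolean Series: keep rows where A is positive.
rows = df.loc[lambda d: d['A'] > 0, :]

# Callable returning a list of labels for the column axis.
cols = df.loc[:, lambda d: ['A']]

# Callable returning integer positions for .iloc.
first_col = df.iloc[:, lambda d: [0]]

# Callable returning a single column label for [] indexing.
col_a = df[lambda d: d.columns[0]]

# The same idea applied to a Series.
positive = df['A'].loc[lambda s: s > 0]
```

In each case the callable receives the object being indexed, so the selection does not depend on the name of an outer variable.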
490542
.. _indexing.basics.partial_setting:
491543

492544
Selecting Random Samples
@@ -848,6 +900,19 @@ This is equivalent (but faster than) the following.
848900
df2 = df.copy()
849901
df.apply(lambda x, y: x.where(x>0,y), y=df['A'])
850902
903+
.. versionadded:: 0.18.1
904+
905+
``where`` can accept a callable as the ``cond`` and ``other`` arguments. The function must
take one argument (the calling Series or DataFrame) and return valid output
for the ``cond`` or ``other`` argument.
908+
909+
.. ipython:: python
910+
911+
df3 = pd.DataFrame({'A': [1, 2, 3],
912+
'B': [4, 5, 6],
913+
'C': [7, 8, 9]})
914+
df3.where(lambda x: x > 4, lambda x: x + 10)
915+
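To see what the callable form of ``where`` computes, it can be compared with the explicit form in which the condition and the replacement are built beforehand. This is a sketch under the assumption that both callables are simply evaluated against the calling object, as the text above describes; the names ``via_callable`` and ``explicit`` are illustrative.

```python
import pandas as pd

df3 = pd.DataFrame({'A': [1, 2, 3],
                    'B': [4, 5, 6],
                    'C': [7, 8, 9]})

# Callable form: both arguments are evaluated against df3 itself.
via_callable = df3.where(lambda x: x > 4, lambda x: x + 10)

# Explicit form: condition and replacement computed up front.
explicit = df3.where(df3 > 4, df3 + 10)
```

Values greater than 4 are kept, and every other value is replaced by itself plus 10, so the two results should be identical.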
851916
**mask**
852917

853918
``mask`` is the inverse boolean operation of ``where``.

doc/source/whatsnew/v0.18.1.txt

+62
@@ -13,6 +13,8 @@ Highlights include:
1313
- ``pd.to_datetime()`` has gained the ability to assemble dates from a ``DataFrame``, see :ref:`here <whatsnew_0181.enhancements.assembling>`
1414
- Custom business hour offset, see :ref:`here <whatsnew_0181.enhancements.custombusinesshour>`.
1515
- Many bug fixes in the handling of ``sparse``, see :ref:`here <whatsnew_0181.sparse>`
16+
- Method chaining improvements, see :ref:`here <whatsnew_0181.enhancements.method_chain>`.
17+
1618

1719
.. contents:: What's new in v0.18.1
1820
:local:
@@ -94,6 +96,66 @@ Now you can do:
9496

9597
df.groupby('group').resample('1D').ffill()
9698

99+
.. _whatsnew_0181.enhancements.method_chain:
100+
101+
Method chaining improvements
102+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
103+
104+
The following methods / indexers now accept a ``callable``. This is intended to make
105+
these more useful in method chains, see :ref:`Selection By Callable <indexing.callable>`.
106+
(:issue:`11485`, :issue:`12533`)
107+
108+
- ``.where()`` and ``.mask()``
109+
- ``.loc[]``, ``.iloc[]`` and ``.ix[]``
110+
- ``[]`` indexing
111+
112+
``.where()`` and ``.mask()``
113+
""""""""""""""""""""""""""""
114+
115+
These can accept a callable as the ``cond`` and ``other``
116+
arguments.
117+
118+
.. ipython:: python
119+
120+
df = pd.DataFrame({'A': [1, 2, 3],
121+
'B': [4, 5, 6],
122+
'C': [7, 8, 9]})
123+
df.where(lambda x: x > 4, lambda x: x + 10)
124+
125+
``.loc[]``, ``.iloc[]``, ``.ix[]``
126+
""""""""""""""""""""""""""""""""""
127+
128+
These can accept a callable, or a tuple of callables, as a slicer. The callable
can return a valid boolean indexer or anything else that is valid input for these indexers.
130+
131+
.. ipython:: python
132+
133+
# callable returns bool indexer
134+
df.loc[lambda x: x.A >= 2, lambda x: x.sum() > 10]
135+
136+
# callable returns list of labels
137+
df.loc[lambda x: [1, 2], lambda x: ['A', 'B']]
138+
139+
``[]`` indexing
140+
"""""""""""""""
141+
142+
Finally, you can use a callable in ``[]`` indexing of Series, DataFrame and Panel.
143+
The callable must return valid input for ``[]`` indexing depending on its
144+
class and index type.
145+
146+
.. ipython:: python
147+
148+
df[lambda x: 'A']
149+
150+
Using these methods / indexers, you can chain data selection operations
151+
without using a temporary variable.
152+
153+
.. ipython:: python
154+
155+
bb = pd.read_csv('data/baseball.csv', index_col='id')
156+
(bb.groupby(['year', 'team']).sum()
157+
.loc[lambda df: df.r > 100])
158+
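The chaining example above reads ``data/baseball.csv``, which is not included here. The sketch below reproduces the same pattern with a tiny made-up frame standing in for that file; the column values are hypothetical and chosen only so that one group passes the filter.

```python
import pandas as pd

# Hypothetical stand-in for data/baseball.csv (values are invented).
bb = pd.DataFrame({
    'year': [2006, 2006, 2007, 2007],
    'team': ['NYA', 'NYA', 'BOS', 'BOS'],
    'r': [60, 70, 30, 40],
})

# Aggregate, then filter the aggregated result in the same chain.
# The lambda receives the grouped-and-summed frame, which has no name
# outside the chain, so a plain boolean mask could not be written here
# without a temporary variable.
result = (bb.groupby(['year', 'team'])[['r']].sum()
            .loc[lambda df: df.r > 100])
```

Only the ``(2006, 'NYA')`` group has a run total above 100, so it is the only row left after the filter.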
97159
.. _whatsnew_0181.partial_string_indexing:
98160

99161
Partial string indexing on ``DateTimeIndex`` when part of a ``MultiIndex``

pandas/core/common.py

+10
@@ -1843,6 +1843,16 @@ def _get_callable_name(obj):
18431843
return None
18441844

18451845

1846+
def _apply_if_callable(maybe_callable, obj, **kwargs):
1847+
"""
1848+
Evaluate possibly callable input using obj and kwargs if it is callable,
1849+
otherwise return it as is
1850+
"""
1851+
if callable(maybe_callable):
1852+
return maybe_callable(obj, **kwargs)
1853+
return maybe_callable
1854+
1855+
18461856
_string_dtypes = frozenset(map(_get_dtype_from_object, (compat.binary_type,
18471857
compat.text_type)))
18481858
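The helper added in this hunk is small enough to exercise on its own. The following is a standalone sketch that copies its logic under the public-looking name ``apply_if_callable`` (an illustrative name, not the pandas private function) and applies it to plain Python objects instead of pandas containers.

```python
def apply_if_callable(maybe_callable, obj, **kwargs):
    """Call maybe_callable with obj and kwargs if it is callable;
    otherwise return it unchanged."""
    if callable(maybe_callable):
        return maybe_callable(obj, **kwargs)
    return maybe_callable


values = [3, 1, 2]

# A non-callable key is passed through untouched.
as_is = apply_if_callable('A', values)

# A callable is invoked with the object as its only positional argument.
computed = apply_if_callable(max, values)

# Extra keyword arguments are forwarded to the callable.
with_kwargs = apply_if_callable(sorted, values, reverse=True)
```

One consequence of relying on ``callable()`` is that any object with a ``__call__`` method, including classes and functions, is invoked rather than treated as a literal key.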

pandas/core/frame.py

+9-7
@@ -1970,6 +1970,7 @@ def iget_value(self, i, j):
19701970
return self.iat[i, j]
19711971

19721972
def __getitem__(self, key):
1973+
key = com._apply_if_callable(key, self)
19731974

19741975
# shortcut if we are an actual column
19751976
is_mi_columns = isinstance(self.columns, MultiIndex)
@@ -2138,6 +2139,9 @@ def query(self, expr, inplace=False, **kwargs):
21382139
>>> df.query('a > b')
21392140
>>> df[df.a > df.b] # same result as the previous expression
21402141
"""
2142+
if not isinstance(expr, compat.string_types):
2143+
msg = "expr must be a string to be evaluated, {0} given"
2144+
raise ValueError(msg.format(type(expr)))
21412145
kwargs['level'] = kwargs.pop('level', 0) + 1
21422146
kwargs['target'] = None
21432147
res = self.eval(expr, **kwargs)
@@ -2336,6 +2340,7 @@ def _box_col_values(self, values, items):
23362340
name=items, fastpath=True)
23372341

23382342
def __setitem__(self, key, value):
2343+
key = com._apply_if_callable(key, self)
23392344

23402345
# see if we can slice the rows
23412346
indexer = convert_to_index_sliceable(self, key)
@@ -2454,8 +2459,9 @@ def assign(self, **kwargs):
24542459
kwargs : keyword, value pairs
24552460
keywords are the column names. If the values are
24562461
callable, they are computed on the DataFrame and
2457-
assigned to the new columns. If the values are
2458-
not callable, (e.g. a Series, scalar, or array),
2462+
assigned to the new columns. The callable must not
2463+
change the input DataFrame (though pandas doesn't check this).
2464+
If the values are not callable, (e.g. a Series, scalar, or array),
24592465
they are simply assigned.
24602466
24612467
Returns
@@ -2513,11 +2519,7 @@ def assign(self, **kwargs):
25132519
# do all calculations first...
25142520
results = {}
25152521
for k, v in kwargs.items():
2516-
2517-
if callable(v):
2518-
results[k] = v(data)
2519-
else:
2520-
results[k] = v
2522+
results[k] = com._apply_if_callable(v, data)
25212523

25222524
# ... and then assign
25232525
for k, v in sorted(results.items()):
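The refactored loop in ``assign`` treats callable and non-callable values uniformly. A minimal sketch of the behaviour described in the docstring follows, using an invented frame and column names.

```python
import pandas as pd

df = pd.DataFrame({'A': [1, 2, 3]})

result = df.assign(
    B=lambda d: d['A'] * 2,  # callable: computed on the DataFrame
    C=10,                    # non-callable: assigned as given
)
```

``assign`` returns a new object, so ``df`` itself still has only column ``A`` afterwards.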

pandas/core/generic.py

+26-2
@@ -4283,8 +4283,26 @@ def _align_series(self, other, join='outer', axis=None, level=None,
42834283
42844284
Parameters
42854285
----------
4286-
cond : boolean %(klass)s or array
4287-
other : scalar or %(klass)s
4286+
cond : boolean %(klass)s, array or callable
4287+
If cond is callable, it is computed on the %(klass)s and
4288+
should return boolean %(klass)s or array.
4289+
The callable must not change the input %(klass)s
4290+
(though pandas doesn't check it).
4291+
4292+
.. versionadded:: 0.18.1
4293+
4294+
A callable can be used as cond.
4295+
4296+
other : scalar, %(klass)s, or callable
4297+
If other is callable, it is computed on the %(klass)s and
4298+
should return scalar or %(klass)s.
4299+
The callable must not change the input %(klass)s
4300+
(though pandas doesn't check it).
4301+
4302+
.. versionadded:: 0.18.1
4303+
4304+
A callable can be used as other.
4305+
42884306
inplace : boolean, default False
42894307
Whether to perform the operation in place on the data
42904308
axis : alignment axis if needed, default None
@@ -4304,6 +4322,9 @@ def _align_series(self, other, join='outer', axis=None, level=None,
43044322
def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
43054323
try_cast=False, raise_on_error=True):
43064324

4325+
cond = com._apply_if_callable(cond, self)
4326+
other = com._apply_if_callable(other, self)
4327+
43074328
if isinstance(cond, NDFrame):
43084329
cond, _ = cond.align(self, join='right', broadcast_axis=1)
43094330
else:
@@ -4461,6 +4482,9 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
44614482
@Appender(_shared_docs['where'] % dict(_shared_doc_kwargs, cond="False"))
44624483
def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None,
44634484
try_cast=False, raise_on_error=True):
4485+
4486+
cond = com._apply_if_callable(cond, self)
4487+
44644488
return self.where(~cond, other=other, inplace=inplace, axis=axis,
44654489
level=level, try_cast=try_cast,
44664490
raise_on_error=raise_on_error)
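In this hunk ``mask`` resolves a callable ``cond`` before negating it, since ``~`` cannot be applied to a function, while ``other`` is passed through to ``where``, which resolves it there. The sketch below illustrates the resulting behaviour on an invented Series; the variable names are illustrative.

```python
import pandas as pd

s = pd.Series([1, 2, 3, 4])

# mask replaces values where the condition is True ...
masked = s.mask(lambda x: x > 2, 0)

# ... which matches where() with the opposite condition.
kept = s.where(lambda x: x <= 2, 0)

# Both cond and other may be callables evaluated on the Series.
negated = s.mask(lambda x: x > 2, lambda x: -x)
```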
