Skip to content

Commit ddbdb7a

Browse files
committed
ENH: Allow usecols to accept callable (GH14154)
1 parent b1d9599 commit ddbdb7a

File tree

4 files changed

+79
-24
lines changed

4 files changed

+79
-24
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Other enhancements
3131
^^^^^^^^^^^^^^^^^^
3232

3333
- ``pd.read_excel`` now preserves sheet order when using ``sheetname=None`` (:issue:`9930`)
34+
- The ``usecols`` argument now accepts a callable function as a value (:issue:`14154`)
3435

3536

3637
.. _whatsnew_0200.api_breaking:

pandas/io/parsers.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pandas.util.decorators import Appender
3535

3636
import pandas.lib as lib
37+
import pandas.core.common as com
3738
import pandas.parser as _parser
3839

3940

@@ -86,13 +87,14 @@
8687
MultiIndex is used. If you have a malformed file with delimiters at the end
8788
of each line, you might consider index_col=False to force pandas to _not_
8889
use the first column as the index (row names)
89-
usecols : array-like, default None
90-
Return a subset of the columns. All elements in this array must either
90+
usecols : array-like or callable, default None
91+
Return a subset of the columns. If array-like, all elements must either
9192
be positional (i.e. integer indices into the document columns) or strings
9293
that correspond to column names provided either by the user in `names` or
9394
inferred from the document header row(s). For example, a valid `usecols`
94-
parameter would be [0, 1, 2] or ['foo', 'bar', 'baz']. Using this parameter
95-
results in much faster parsing time and lower memory usage.
95+
parameter would be [0, 1, 2], ['foo', 'bar', 'baz'] or lambda x: x.upper()
96+
in ['AAA', 'BBB', 'DDD']. Using this parameter results in much faster
97+
parsing time and lower memory usage.
9698
as_recarray : boolean, default False
9799
DEPRECATED: this argument will be removed in a future version. Please call
98100
`pd.read_csv(...).to_records()` instead.
@@ -976,17 +978,27 @@ def _is_index_col(col):
976978
return col is not None and col is not False
977979

978980

981+
def _evaluate_usecols(usecols, names):
982+
if callable(usecols):
983+
return set([i for i, name in enumerate(names)
984+
if com._apply_if_callable(usecols, name)])
985+
else:
986+
return usecols
987+
988+
979989
def _validate_usecols_arg(usecols):
980990
"""
981991
Check whether or not the 'usecols' parameter
982992
contains all integers (column selection by index)
983993
or strings (column by name). Raises a ValueError
984994
if that is not the case.
985995
"""
986-
msg = ("The elements of 'usecols' must "
987-
"either be all strings, all unicode, or all integers")
996+
msg = ("'usecols' must either be all strings, all unicode, "
997+
"all integers or callable")
988998

989999
if usecols is not None:
1000+
if callable(usecols):
1001+
return usecols
9901002
usecols_dtype = lib.infer_dtype(usecols)
9911003
if usecols_dtype not in ('empty', 'integer',
9921004
'string', 'unicode'):
@@ -1426,11 +1438,12 @@ def __init__(self, src, **kwds):
14261438
self.orig_names = self.names[:]
14271439

14281440
if self.usecols:
1429-
if len(self.names) > len(self.usecols):
1441+
usecols = _evaluate_usecols(self.usecols, self.orig_names)
1442+
if len(self.names) > len(usecols):
14301443
self.names = [n for i, n in enumerate(self.names)
1431-
if (i in self.usecols or n in self.usecols)]
1444+
if (i in usecols or n in usecols)]
14321445

1433-
if len(self.names) < len(self.usecols):
1446+
if len(self.names) < len(usecols):
14341447
raise ValueError("Usecols do not match names.")
14351448

14361449
self._set_noconvert_columns()
@@ -1592,9 +1605,10 @@ def read(self, nrows=None):
15921605

15931606
def _filter_usecols(self, names):
15941607
# hackish
1595-
if self.usecols is not None and len(names) != len(self.usecols):
1608+
usecols = _evaluate_usecols(self.usecols, names)
1609+
if usecols is not None and len(names) != len(usecols):
15961610
names = [name for i, name in enumerate(names)
1597-
if i in self.usecols or name in self.usecols]
1611+
if i in usecols or name in usecols]
15981612
return names
15991613

16001614
def _get_index_names(self):
@@ -2207,7 +2221,9 @@ def _handle_usecols(self, columns, usecols_key):
22072221
usecols_key is used if there are string usecols.
22082222
"""
22092223
if self.usecols is not None:
2210-
if any([isinstance(col, string_types) for col in self.usecols]):
2224+
if callable(self.usecols):
2225+
col_indices = _evaluate_usecols(self.usecols, usecols_key)
2226+
elif any([isinstance(u, string_types) for u in self.usecols]):
22112227
if len(columns) > 1:
22122228
raise ValueError("If using multiple headers, usecols must "
22132229
"be integers.")

pandas/io/tests/parser/usecols.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ def test_raise_on_mixed_dtype_usecols(self):
2323
1000,2000,3000
2424
4000,5000,6000
2525
"""
26-
msg = ("The elements of 'usecols' must "
27-
"either be all strings, all unicode, or all integers")
26+
27+
msg = ("'usecols' must either be all strings, all unicode, "
28+
"all integers or callable")
2829
usecols = [0, 'b', 2]
2930

3031
with tm.assertRaisesRegexp(ValueError, msg):
@@ -302,8 +303,8 @@ def test_usecols_with_mixed_encoding_strings(self):
302303
3.568935038,7,False,a
303304
'''
304305

305-
msg = ("The elements of 'usecols' must "
306-
"either be all strings, all unicode, or all integers")
306+
msg = ("'usecols' must either be all strings, all unicode, "
307+
"all integers or callable")
307308

308309
with tm.assertRaisesRegexp(ValueError, msg):
309310
self.read_csv(StringIO(s), usecols=[u'AAA', b'BBB'])
@@ -366,3 +367,25 @@ def test_np_array_usecols(self):
366367
expected = DataFrame([[1, 2]], columns=usecols)
367368
result = self.read_csv(StringIO(data), usecols=usecols)
368369
tm.assert_frame_equal(result, expected)
370+
371+
def test_callable_usecols(self):
372+
s = '''AaA,bBb,CCC,ddd
373+
0.056674973,8,True,a
374+
2.613230982,2,False,b
375+
3.568935038,7,False,a
376+
'''
377+
378+
data = {
379+
'AaA': {
380+
0: 0.056674972999999997,
381+
1: 2.6132309819999997,
382+
2: 3.5689350380000002
383+
},
384+
'bBb': {0: 8, 1: 2, 2: 7},
385+
'ddd': {0: 'a', 1: 'b', 2: 'a'}
386+
}
387+
expected = DataFrame(data)
388+
389+
df = self.read_csv(StringIO(s), usecols=lambda x:
390+
x.upper() in ['AAA', 'BBB', 'DDD'])
391+
tm.assert_frame_equal(df, expected)

pandas/parser.pyx

+23-8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ cimport util
3838

3939
import pandas.lib as lib
4040
import pandas.compat as compat
41+
import pandas.core.common as com
4142
from pandas.types.common import (is_categorical_dtype, CategoricalDtype,
4243
is_integer_dtype, is_float_dtype,
4344
is_bool_dtype, is_object_dtype,
@@ -276,7 +277,7 @@ cdef class TextReader:
276277
object file_handle, na_fvalues
277278
object true_values, false_values
278279
object handle
279-
bint na_filter, verbose, has_usecols, has_mi_columns
280+
bint na_filter, verbose, has_usecols, has_mi_columns, callable_usecols
280281
int parser_start
281282
list clocks
282283
char *c_encoding
@@ -300,8 +301,10 @@ cdef class TextReader:
300301
object compression
301302
object mangle_dupe_cols
302303
object tupleize_cols
304+
object usecols
303305
list dtype_cast_order
304-
set noconvert, usecols
306+
set noconvert
307+
305308

306309
def __cinit__(self, source,
307310
delimiter=b',',
@@ -437,7 +440,11 @@ cdef class TextReader:
437440
# suboptimal
438441
if usecols is not None:
439442
self.has_usecols = 1
440-
self.usecols = set(usecols)
443+
if callable(usecols):
444+
self.callable_usecols = 1
445+
self.usecols = usecols
446+
else:
447+
self.usecols = set(usecols)
441448

442449
# XXX
443450
if skipfooter > 0:
@@ -701,7 +708,6 @@ cdef class TextReader:
701708
cdef StringPath path = _string_path(self.c_encoding)
702709

703710
header = []
704-
705711
if self.parser.header_start >= 0:
706712

707713
# Header is in the file
@@ -821,7 +827,7 @@ cdef class TextReader:
821827
# 'data has %d fields'
822828
# % (passed_count, field_count))
823829

824-
if self.has_usecols and self.allow_leading_cols:
830+
if self.has_usecols and self.allow_leading_cols and not self.callable_usecols:
825831
nuse = len(self.usecols)
826832
if nuse == passed_count:
827833
self.leading_cols = 0
@@ -1015,17 +1021,25 @@ cdef class TextReader:
10151021

10161022
results = {}
10171023
nused = 0
1024+
10181025
for i in range(self.table_width):
1026+
10191027
if i < self.leading_cols:
10201028
# Pass through leading columns always
10211029
name = i
1022-
elif self.usecols and nused == len(self.usecols):
1030+
elif self.usecols and not self.callable_usecols and nused == len(self.usecols):
10231031
# Once we've gathered all requested columns, stop. GH5766
10241032
break
10251033
else:
10261034
name = self._get_column_name(i, nused)
1027-
if self.has_usecols and not (i in self.usecols or
1028-
name in self.usecols):
1035+
usecols = set()
1036+
if self.callable_usecols:
1037+
if com._apply_if_callable(self.usecols, name):
1038+
usecols = set([i])
1039+
else:
1040+
usecols = self.usecols
1041+
if self.has_usecols and not (i in usecols or
1042+
name in usecols):
10291043
continue
10301044
nused += 1
10311045

@@ -1341,6 +1355,7 @@ def _maybe_upcast(arr):
13411355

13421356
return arr
13431357

1358+
13441359
cdef enum StringPath:
13451360
CSTRING
13461361
UTF8

0 commit comments

Comments
 (0)