diff --git a/doc/source/basics.rst b/doc/source/basics.rst index 134cc5106015b..d8b1602fb104d 100644 --- a/doc/source/basics.rst +++ b/doc/source/basics.rst @@ -2229,7 +2229,3 @@ All numpy dtypes are subclasses of ``numpy.generic``: Pandas also defines the types ``category``, and ``datetime64[ns, tz]``, which are not integrated into the normal numpy hierarchy and wont show up with the above function. - -.. note:: - - The ``include`` and ``exclude`` parameters must be non-string sequences. diff --git a/doc/source/style.ipynb b/doc/source/style.ipynb index 4eeda491426b1..c250787785e14 100644 --- a/doc/source/style.ipynb +++ b/doc/source/style.ipynb @@ -935,7 +935,7 @@ "\n", "*Experimental: This is a new feature and still under development. We'll be adding features and possibly making breaking changes in future releases. We'd love to hear your feedback.*\n", "\n", - "Some support is available for exporting styled `DataFrames` to Excel worksheets using the `OpenPyXL` engine. CSS2.2 properties handled include:\n", + "Some support is available for exporting styled `DataFrames` to Excel worksheets using the `OpenPyXL` engine. CSS2.2 properties handled include:\n", "\n", "- `background-color`\n", "- `border-style`, `border-width`, `border-color` and their {`top`, `right`, `bottom`, `left` variants}\n", diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index d5cc3d6ddca8e..6968bbebc836c 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -39,6 +39,7 @@ Other Enhancements - :func:`read_feather` has gained the ``nthreads`` parameter for multi-threaded operations (:issue:`16359`) - :func:`DataFrame.clip()` and :func:`Series.clip()` have gained an ``inplace`` argument. (:issue:`15388`) - :func:`crosstab` has gained a ``margins_name`` parameter to define the name of the row / column that will contain the totals when ``margins=True``. 
(:issue:`15972`) +- :func:`DataFrame.select_dtypes` now accepts scalar values for include/exclude as well as list-like. (:issue:`16855`) .. _whatsnew_0210.api_breaking: diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 80cdebc24c39d..6559fc4c24ce2 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -2285,9 +2285,9 @@ def select_dtypes(self, include=None, exclude=None): Parameters ---------- - include, exclude : list-like - A list of dtypes or strings to be included/excluded. You must pass - in a non-empty sequence for at least one of these. + include, exclude : scalar or list-like + A selection of dtypes or strings to be included/excluded. At least + one of these parameters must be supplied. Raises ------ @@ -2295,8 +2295,6 @@ def select_dtypes(self, include=None, exclude=None): * If both of ``include`` and ``exclude`` are empty * If ``include`` and ``exclude`` have overlapping elements * If any kind of string dtype is passed in. - TypeError - * If either of ``include`` or ``exclude`` is not a sequence Returns ------- @@ -2331,6 +2329,14 @@ def select_dtypes(self, include=None, exclude=None): 3 0.0764 False 2 4 -0.9703 True 1 5 -1.2094 False 2 + >>> df.select_dtypes(include='bool') + b + 0 True + 1 False + 2 True + 3 False + 4 True + 5 False >>> df.select_dtypes(include=['float64']) c 0 1 @@ -2348,10 +2354,12 @@ def select_dtypes(self, include=None, exclude=None): 4 True 5 False """ - include, exclude = include or (), exclude or () - if not (is_list_like(include) and is_list_like(exclude)): - raise TypeError('include and exclude must both be non-string' - ' sequences') + + if not is_list_like(include): + include = (include,) if include is not None else () + if not is_list_like(exclude): + exclude = (exclude,) if exclude is not None else () + selection = tuple(map(frozenset, (include, exclude))) if not any(selection): diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index 335b76ff2aade..065580d56a683 100644 
--- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -104,7 +104,7 @@ def test_dtypes_are_correct_after_column_slice(self): ('b', np.float_), ('c', np.float_)]))) - def test_select_dtypes_include(self): + def test_select_dtypes_include_using_list_like(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), 'c': np.arange(3, 6).astype('u1'), @@ -145,14 +145,10 @@ def test_select_dtypes_include(self): ei = df[['h', 'i']] assert_frame_equal(ri, ei) - ri = df.select_dtypes(include=['timedelta']) - ei = df[['k']] - assert_frame_equal(ri, ei) - pytest.raises(NotImplementedError, lambda: df.select_dtypes(include=['period'])) - def test_select_dtypes_exclude(self): + def test_select_dtypes_exclude_using_list_like(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), 'c': np.arange(3, 6).astype('u1'), @@ -162,7 +158,7 @@ def test_select_dtypes_exclude(self): ee = df[['a', 'e']] assert_frame_equal(re, ee) - def test_select_dtypes_exclude_include(self): + def test_select_dtypes_exclude_include_using_list_like(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), 'c': np.arange(3, 6).astype('u1'), @@ -181,6 +177,114 @@ def test_select_dtypes_exclude_include(self): e = df[['b', 'e']] assert_frame_equal(r, e) + def test_select_dtypes_include_using_scalars(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(include=np.number) + ei = df[['b', 'c', 'd', 'k']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include='datetime') + ei = df[['g']] + 
assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include='datetime64') + ei = df[['g']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include='category') + ei = df[['f']] + assert_frame_equal(ri, ei) + + pytest.raises(NotImplementedError, + lambda: df.select_dtypes(include='period')) + + def test_select_dtypes_exclude_using_scalars(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(exclude=np.number) + ei = df[['a', 'e', 'f', 'g', 'h', 'i', 'j']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(exclude='category') + ei = df[['a', 'b', 'c', 'd', 'e', 'g', 'h', 'i', 'j', 'k']] + assert_frame_equal(ri, ei) + + pytest.raises(NotImplementedError, + lambda: df.select_dtypes(exclude='period')) + + def test_select_dtypes_include_exclude_using_scalars(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(include=np.number, exclude='floating') + ei = df[['b', 'c', 'k']] + assert_frame_equal(ri, ei) + + def test_select_dtypes_include_exclude_mixed_scalars_lists(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 
6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(include=np.number, + exclude=['floating', 'timedelta']) + ei = df[['b', 'c']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include=[np.number, 'category'], + exclude='floating') + ei = df[['b', 'c', 'f', 'k']] + assert_frame_equal(ri, ei) + def test_select_dtypes_not_an_attr_but_still_valid_dtype(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), @@ -205,18 +309,6 @@ def test_select_dtypes_empty(self): 'must be nonempty'): df.select_dtypes() - def test_select_dtypes_raises_on_string(self): - df = DataFrame({'a': list('abc'), 'b': list(range(1, 4))}) - with tm.assert_raises_regex(TypeError, 'include and exclude ' - '.+ non-'): - df.select_dtypes(include='object') - with tm.assert_raises_regex(TypeError, 'include and exclude ' - '.+ non-'): - df.select_dtypes(exclude='object') - with tm.assert_raises_regex(TypeError, 'include and exclude ' - '.+ non-'): - df.select_dtypes(include=int, exclude='object') - def test_select_dtypes_bad_datetime64(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)),