diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index b2b74e2a70ca9..ae0744f25c57d 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1576,7 +1576,9 @@ def is_numeric_dtype(arr_or_dtype): >>> is_numeric_dtype(np.array([], dtype=np.timedelta64)) False """ - + if is_extension_array_dtype(arr_or_dtype): + dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype) + return dtype._is_numeric return _is_dtype_type( arr_or_dtype, classes_and_not_datetimelike(np.number, np.bool_)) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index df7003ecf000e..254f096fb8020 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -41,8 +41,8 @@ is_bool_dtype, is_datetime64_any_dtype, is_datetime64tz_dtype, is_dict_like, is_dtype_equal, is_extension_array_dtype, is_extension_type, is_float_dtype, is_integer, is_integer_dtype, is_iterator, is_list_like, - is_named_tuple, is_nested_list_like, is_object_dtype, is_scalar, - is_sequence, needs_i8_conversion) + is_named_tuple, is_nested_list_like, is_numeric_dtype, is_object_dtype, + is_scalar, is_sequence, needs_i8_conversion) from pandas.core.dtypes.generic import ( ABCDataFrame, ABCIndexClass, ABCMultiIndex, ABCSeries) from pandas.core.dtypes.missing import isna, notna @@ -3265,6 +3265,19 @@ def _get_info_slice(obj, indexer): for dtypes in (include, exclude): invalidate_string_dtypes(dtypes) + def add_extension_types(dtypes, search_dtype, func): + """Adds bool or numeric extension types to include/exclude""" + extension_dtypes = [dtype.type for dtype in self.dtypes + if is_extension_array_dtype(dtype) and + func(dtype)] + if search_dtype in dtypes: + return frozenset(dtypes.union(extension_dtypes)) + else: + return dtypes + include = add_extension_types(include, np.number, is_numeric_dtype) + exclude = add_extension_types(exclude, np.number, is_numeric_dtype) + include = add_extension_types(include, np.bool_, is_bool_dtype) + exclude = add_extension_types(exclude, np.bool_, is_bool_dtype) # can't both include AND exclude! if not include.isdisjoint(exclude): raise ValueError('include and exclude overlap on {inc_ex}'.format(