Skip to content

PERF: use is_foo_dtype fastpaths #34111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def func(self, other):
"use 'np.asarray(cat) <op> other'."
)

if isinstance(other, ExtensionArray) and needs_i8_conversion(other):
if isinstance(other, ExtensionArray) and needs_i8_conversion(other.dtype):
# We would return NotImplemented here, but that messes up
# ExtensionIndex's wrapped methods
return op(other, self)
Expand Down
29 changes: 16 additions & 13 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class TimelikeOps:

def _round(self, freq, mode, ambiguous, nonexistent):
# round the local times
if is_datetime64tz_dtype(self):
if is_datetime64tz_dtype(self.dtype):
# operate on naive timestamps, then convert back to aware
naive = self.tz_localize(None)
result = naive._round(freq, mode, ambiguous, nonexistent)
Expand Down Expand Up @@ -1032,7 +1032,7 @@ def fillna(self, value=None, method=None, limit=None):
values = values.copy()

new_values = func(values, limit=limit, mask=mask)
if is_datetime64tz_dtype(self):
if is_datetime64tz_dtype(self.dtype):
# we need to pass int64 values to the constructor to avoid
# re-localizing incorrectly
new_values = new_values.view("i8")
Expand Down Expand Up @@ -1379,6 +1379,7 @@ def _time_shift(self, periods, freq=None):

@unpack_zerodim_and_defer("__add__")
def __add__(self, other):
other_dtype = getattr(other, "dtype", None)

# scalar others
if other is NaT:
Expand All @@ -1398,16 +1399,16 @@ def __add__(self, other):
result = self._time_shift(other)

# array-like others
elif is_timedelta64_dtype(other):
elif is_timedelta64_dtype(other_dtype):
# TimedeltaIndex, ndarray[timedelta64]
result = self._add_timedelta_arraylike(other)
elif is_object_dtype(other):
elif is_object_dtype(other_dtype):
# e.g. Array/Index of DateOffset objects
result = self._addsub_object_array(other, operator.add)
elif is_datetime64_dtype(other) or is_datetime64tz_dtype(other):
elif is_datetime64_dtype(other_dtype) or is_datetime64tz_dtype(other_dtype):
# DatetimeIndex, ndarray[datetime64]
return self._add_datetime_arraylike(other)
elif is_integer_dtype(other):
elif is_integer_dtype(other_dtype):
if not is_period_dtype(self.dtype):
raise integer_op_not_supported(self)
result = self._addsub_int_array(other, operator.add)
Expand All @@ -1419,7 +1420,7 @@ def __add__(self, other):
# In remaining cases, this will end up raising TypeError.
return NotImplemented

if is_timedelta64_dtype(result) and isinstance(result, np.ndarray):
if isinstance(result, np.ndarray) and is_timedelta64_dtype(result.dtype):
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray(result)
Expand Down Expand Up @@ -1455,13 +1456,13 @@ def __sub__(self, other):
result = self._sub_period(other)

# array-like others
elif is_timedelta64_dtype(other):
elif is_timedelta64_dtype(other_dtype):
# TimedeltaIndex, ndarray[timedelta64]
result = self._add_timedelta_arraylike(-other)
elif is_object_dtype(other):
elif is_object_dtype(other_dtype):
# e.g. Array/Index of DateOffset objects
result = self._addsub_object_array(other, operator.sub)
elif is_datetime64_dtype(other) or is_datetime64tz_dtype(other):
elif is_datetime64_dtype(other_dtype) or is_datetime64tz_dtype(other_dtype):
# DatetimeIndex, ndarray[datetime64]
result = self._sub_datetime_arraylike(other)
elif is_period_dtype(other_dtype):
Expand All @@ -1475,14 +1476,16 @@ def __sub__(self, other):
# Includes ExtensionArrays, float_dtype
return NotImplemented

if is_timedelta64_dtype(result) and isinstance(result, np.ndarray):
if isinstance(result, np.ndarray) and is_timedelta64_dtype(result.dtype):
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray(result)
return result

def __rsub__(self, other):
if is_datetime64_any_dtype(other) and is_timedelta64_dtype(self.dtype):
other_dtype = getattr(other, "dtype", None)

if is_datetime64_any_dtype(other_dtype) and is_timedelta64_dtype(self.dtype):
# ndarray[datetime64] cannot be subtracted from self, so
# we need to wrap in DatetimeArray/Index and flip the operation
if lib.is_scalar(other):
Expand All @@ -1504,7 +1507,7 @@ def __rsub__(self, other):
raise TypeError(
f"cannot subtract {type(self).__name__} from {type(other).__name__}"
)
elif is_period_dtype(self.dtype) and is_timedelta64_dtype(other):
elif is_period_dtype(self.dtype) and is_timedelta64_dtype(other_dtype):
# TODO: Can we simplify/generalize these cases at all?
raise TypeError(f"cannot subtract {type(self).__name__} from {other.dtype}")
elif is_timedelta64_dtype(self.dtype):
Expand Down
12 changes: 8 additions & 4 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,9 @@ def _has_same_tz(self, other):
def _assert_tzawareness_compat(self, other):
# adapted from _Timestamp._assert_tzawareness_compat
other_tz = getattr(other, "tzinfo", None)
if is_datetime64tz_dtype(other):
other_dtype = getattr(other, "dtype", None)

if is_datetime64tz_dtype(other_dtype):
# Get tzinfo from Series dtype
other_tz = other.dtype.tz
if other is NaT:
Expand Down Expand Up @@ -1913,8 +1915,9 @@ def sequence_to_dt64ns(

# By this point we are assured to have either a numpy array or Index
data, copy = maybe_convert_dtype(data, copy)
data_dtype = getattr(data, "dtype", None)

if is_object_dtype(data) or is_string_dtype(data):
if is_object_dtype(data_dtype) or is_string_dtype(data_dtype):
# TODO: We do not have tests specific to string-dtypes,
# also complex or categorical or other extension
copy = False
Expand All @@ -1927,15 +1930,16 @@ def sequence_to_dt64ns(
data, dayfirst=dayfirst, yearfirst=yearfirst
)
tz = maybe_infer_tz(tz, inferred_tz)
data_dtype = data.dtype

# `data` may have originally been a Categorical[datetime64[ns, tz]],
# so we need to handle these types.
if is_datetime64tz_dtype(data):
if is_datetime64tz_dtype(data_dtype):
# DatetimeArray -> ndarray
tz = maybe_infer_tz(tz, data.tz)
result = data._data

elif is_datetime64_dtype(data):
elif is_datetime64_dtype(data_dtype):
# tz-naive DatetimeArray or ndarray[datetime64]
data = getattr(data, "_data", data)
if data.dtype != DT64NS_DTYPE:
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ def make_sparse(arr: np.ndarray, kind="block", fill_value=None, dtype=None, copy
mask = notna(arr)
else:
# cast to object comparison to be safe
if is_string_dtype(arr):
if is_string_dtype(arr.dtype):
arr = arr.astype(object)

if is_object_dtype(arr.dtype):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ def maybe_cast_to_datetime(value, dtype, errors: str = "raise"):
# is solved. String data that is passed with a
# datetime64tz is assumed to be naive which should
# be localized to the timezone.
is_dt_string = is_string_dtype(value)
is_dt_string = is_string_dtype(value.dtype)
value = to_datetime(value, errors=errors).array
if is_dt_string:
# Strings here are naive, so directly localize
Expand Down
8 changes: 4 additions & 4 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def get_dtype_kinds(l):
dtype = arr.dtype
if is_categorical_dtype(dtype):
typ = "category"
elif is_sparse(arr):
elif is_sparse(dtype):
typ = "sparse"
elif isinstance(arr, ABCRangeIndex):
typ = "range"
elif is_datetime64tz_dtype(arr):
elif is_datetime64tz_dtype(dtype):
# if to_concat contains different tz,
# the result must be object dtype
typ = str(arr.dtype)
typ = str(dtype)
elif is_datetime64_dtype(dtype):
typ = "datetime"
elif is_timedelta64_dtype(dtype):
Expand All @@ -57,7 +57,7 @@ def get_dtype_kinds(l):
elif is_bool_dtype(dtype):
typ = "bool"
elif is_extension_array_dtype(dtype):
typ = str(arr.dtype)
typ = str(dtype)
else:
typ = dtype.kind
typs.add(typ)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:

# Object arrays can contain None, NaN and NaT.
# string dtypes must be come to this path for NumPy 1.7.1 compat
if is_string_dtype(left) or is_string_dtype(right):
if is_string_dtype(left.dtype) or is_string_dtype(right.dtype):

if not strict_nan:
# isna considers NaN and None to be equivalent.
Expand Down
16 changes: 8 additions & 8 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6907,9 +6907,9 @@ def interpolate(
index = df.index
methods = {"index", "values", "nearest", "time"}
is_numeric_or_datetime = (
is_numeric_dtype(index)
or is_datetime64_any_dtype(index)
or is_timedelta64_dtype(index)
is_numeric_dtype(index.dtype)
or is_datetime64_any_dtype(index.dtype)
or is_timedelta64_dtype(index.dtype)
)
if method not in methods and not is_numeric_or_datetime:
raise ValueError(
Expand Down Expand Up @@ -8588,7 +8588,7 @@ def _align_frame(
right = right.fillna(method=method, axis=fill_axis, limit=limit)

# if DatetimeIndex have different tz, convert to UTC
if is_datetime64tz_dtype(left.index):
if is_datetime64tz_dtype(left.index.dtype):
if left.index.tz != right.index.tz:
if join_index is not None:
left.index = join_index
Expand Down Expand Up @@ -8675,7 +8675,7 @@ def _align_series(

# if DatetimeIndex have different tz, convert to UTC
if is_series or (not is_series and axis == 0):
if is_datetime64tz_dtype(left.index):
if is_datetime64tz_dtype(left.index.dtype):
if left.index.tz != right.index.tz:
if join_index is not None:
left.index = join_index
Expand Down Expand Up @@ -9957,13 +9957,13 @@ def describe_timestamp_1d(data):
return pd.Series(d, index=stat_index, name=data.name)

def describe_1d(data):
if is_bool_dtype(data):
if is_bool_dtype(data.dtype):
return describe_categorical_1d(data)
elif is_numeric_dtype(data):
return describe_numeric_1d(data)
elif is_datetime64_any_dtype(data):
elif is_datetime64_any_dtype(data.dtype):
return describe_timestamp_1d(data)
elif is_timedelta64_dtype(data):
elif is_timedelta64_dtype(data.dtype):
return describe_numeric_1d(data)
else:
return describe_categorical_1d(data)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def value_counts(
lab = lev.take(lab.cat.codes)
llab = lambda lab, inc: lab[inc]._multiindex.codes[-1]

if is_interval_dtype(lab):
if is_interval_dtype(lab.dtype):
# TODO: should we do this inside II?
sorter = np.lexsort((lab.left, lab.right, ids))
else:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,12 @@ def _cython_operation(
# are not setup for dim transforming
if is_categorical_dtype(values.dtype) or is_sparse(values.dtype):
raise NotImplementedError(f"{values.dtype} dtype not supported")
elif is_datetime64_any_dtype(values):
elif is_datetime64_any_dtype(values.dtype):
if how in ["add", "prod", "cumsum", "cumprod"]:
raise NotImplementedError(
f"datetime64 type does not support {how} operations"
)
elif is_timedelta64_dtype(values):
elif is_timedelta64_dtype(values.dtype):
if how in ["prod", "cumprod"]:
raise NotImplementedError(
f"timedelta64 type does not support {how} operations"
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
is_interval_dtype,
is_list_like,
is_scalar,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna
Expand Down Expand Up @@ -372,6 +373,9 @@ def __contains__(self, key: Any) -> bool:

@doc(Index.astype)
def astype(self, dtype, copy=True):
if dtype is not None:
dtype = pandas_dtype(dtype)

if is_interval_dtype(dtype):
from pandas import IntervalIndex

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2714,7 +2714,7 @@ def make_block(values, placement, klass=None, ndim=None, dtype=None):
dtype = dtype or values.dtype
klass = get_block_type(values, dtype)

elif klass is DatetimeTZBlock and not is_datetime64tz_dtype(values):
elif klass is DatetimeTZBlock and not is_datetime64tz_dtype(values.dtype):
# TODO: This is no longer hit internally; does it need to be retained
# for e.g. pyarrow?
values = DatetimeArray._simple_new(values, dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/internals/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def _get_empty_dtype_and_na(join_units):
if len(join_units) == 1:
blk = join_units[0].block
if blk is None:
return np.float64, np.nan
return np.dtype(np.float64), np.nan

if _is_uniform_reindex(join_units):
# FIXME: integrate property
Expand Down Expand Up @@ -424,7 +424,7 @@ def _get_empty_dtype_and_na(join_units):
return g, g.type(np.nan)
elif is_numeric_dtype(g):
if has_none_blocks:
return np.float64, np.nan
return np.dtype(np.float64), np.nan
else:
return g, None

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,12 +759,12 @@ def nanvar(values, axis=None, skipna=True, ddof=1, mask=None):
values = extract_array(values, extract_numpy=True)
dtype = values.dtype
mask = _maybe_get_mask(values, skipna, mask)
if is_any_int_dtype(values):
if is_any_int_dtype(dtype):
values = values.astype("f8")
if mask is not None:
values[mask] = np.nan

if is_float_dtype(values):
if is_float_dtype(values.dtype):
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype)
else:
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof)
Expand Down
Loading