Skip to content

Commit 9b5f2f5

Browse files
committed
BUG: Reverse operators on integer-NA series and numpy scalars resulting in object dtype (pandas-dev#22024)
1 parent 7390963 commit 9b5f2f5

File tree

1 file changed

+28
-49
lines changed

1 file changed

+28
-49
lines changed

pandas/core/ops.py

+28-49
Original file line numberDiff line numberDiff line change
@@ -1088,35 +1088,43 @@ def dispatch_to_extension_op(op, left, right):
10881088
# we need to listify to avoid ndarray, or non-same-type extension array
10891089
# dispatching
10901090

1091+
new_type, left_type, right_type = None, None, None
10911092
if is_extension_array_dtype(left):
1092-
1093+
left_type = left.dtype.type
10931094
new_left = left.values
10941095
if isinstance(right, np.ndarray):
10951096

10961097
# handle numpy scalars, this is a PITA
10971098
# TODO(jreback)
10981099
new_right = lib.item_from_zerodim(right)
1100+
right_type = new_right.dtype
10991101
if is_scalar(new_right):
11001102
new_right = [new_right]
11011103
new_right = list(new_right)
11021104
elif is_extension_array_dtype(right) and type(left) != type(right):
1105+
right_type = new_right.dtype.type
11031106
new_right = list(new_right)
11041107
else:
11051108
new_right = right
1106-
1109+
right_type = type(right)
11071110
else:
11081111

11091112
new_left = list(left.values)
11101113
new_right = right
11111114

11121115
res_values = op(new_left, new_right)
11131116
res_name = get_op_result_name(left, right)
1114-
1117+
if right_type and left_type:
1118+
new_type = find_common_type([right_type, left_type])
11151119
if op.__name__ == 'divmod':
11161120
return _construct_divmod_result(
11171121
left, res_values, left.index, res_name)
11181122

1119-
return _construct_result(left, res_values, left.index, res_name)
1123+
result = _construct_result(left, res_values, left.index, res_name)
1124+
if result.dtype == "object":
1125+
result = _construct_result(left, res_values, left.index, res_name,
1126+
new_type)
1127+
return result
11201128

11211129

11221130
def _arith_method_SERIES(cls, op, special):
@@ -1143,7 +1151,6 @@ def na_op(x, y):
11431151
result[mask] = op(x[mask], com.values_from_object(y[mask]))
11441152
else:
11451153
assert isinstance(x, np.ndarray)
1146-
assert is_scalar(y)
11471154
result = np.empty(len(x), dtype=x.dtype)
11481155
mask = notna(x)
11491156
result[mask] = op(x[mask], y)
@@ -1190,7 +1197,6 @@ def wrapper(left, right):
11901197

11911198
elif (is_extension_array_dtype(left) or
11921199
is_extension_array_dtype(right)):
1193-
# TODO: should this include `not is_scalar(right)`?
11941200
return dispatch_to_extension_op(op, left, right)
11951201

11961202
elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
@@ -1280,11 +1286,13 @@ def na_op(x, y):
12801286
# should have guarantees on what x, y can be type-wise
12811287
# Extension Dtypes are not called here
12821288

1283-
# Checking that cases that were once handled here are no longer
1284-
# reachable.
1285-
assert not (is_categorical_dtype(y) and not is_scalar(y))
1289+
# dispatch to the categorical if we have a categorical
1290+
# in either operand
1291+
if is_categorical_dtype(y) and not is_scalar(y):
1292+
# The `not is_scalar(y)` check excludes the string "category"
1293+
return op(y, x)
12861294

1287-
if is_object_dtype(x.dtype):
1295+
elif is_object_dtype(x.dtype):
12881296
result = _comp_method_OBJECT_ARRAY(op, x, y)
12891297

12901298
elif is_datetimelike_v_numeric(x, y):
@@ -1342,7 +1350,7 @@ def wrapper(self, other, axis=None):
13421350
return self._constructor(res_values, index=self.index,
13431351
name=res_name)
13441352

1345-
elif is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
1353+
if is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
13461354
# Dispatch to DatetimeIndex to ensure identical
13471355
# Series/Index behavior
13481356
if (isinstance(other, datetime.date) and
@@ -1384,9 +1392,8 @@ def wrapper(self, other, axis=None):
13841392
name=res_name)
13851393

13861394
elif (is_extension_array_dtype(self) or
1387-
(is_extension_array_dtype(other) and not is_scalar(other))):
1388-
# Note: the `not is_scalar(other)` condition rules out
1389-
# e.g. other == "category"
1395+
(is_extension_array_dtype(other) and
1396+
not is_scalar(other))):
13901397
return dispatch_to_extension_op(op, self, other)
13911398

13921399
elif isinstance(other, ABCSeries):
@@ -1409,6 +1416,13 @@ def wrapper(self, other, axis=None):
14091416
# is not.
14101417
return result.__finalize__(self).rename(res_name)
14111418

1419+
elif isinstance(other, pd.Categorical):
1420+
# ordering of checks matters; by this point we know
1421+
# that not is_categorical_dtype(self)
1422+
res_values = op(self.values, other)
1423+
return self._constructor(res_values, index=self.index,
1424+
name=res_name)
1425+
14121426
elif is_scalar(other) and isna(other):
14131427
# numpy does not like comparisons vs None
14141428
if op is operator.ne:
@@ -1538,41 +1552,6 @@ def flex_wrapper(self, other, level=None, fill_value=None, axis=0):
15381552
# -----------------------------------------------------------------------------
15391553
# DataFrame
15401554

1541-
def dispatch_to_series(left, right, func):
1542-
"""
1543-
Evaluate the frame operation func(left, right) by evaluating
1544-
column-by-column, dispatching to the Series implementation.
1545-
1546-
Parameters
1547-
----------
1548-
left : DataFrame
1549-
right : scalar or DataFrame
1550-
func : arithmetic or comparison operator
1551-
1552-
Returns
1553-
-------
1554-
DataFrame
1555-
"""
1556-
# Note: we use iloc to access columns for compat with cases
1557-
# with non-unique columns.
1558-
if lib.is_scalar(right):
1559-
new_data = {i: func(left.iloc[:, i], right)
1560-
for i in range(len(left.columns))}
1561-
elif isinstance(right, ABCDataFrame):
1562-
assert right._indexed_same(left)
1563-
new_data = {i: func(left.iloc[:, i], right.iloc[:, i])
1564-
for i in range(len(left.columns))}
1565-
else:
1566-
# Remaining cases have less-obvious dispatch rules
1567-
raise NotImplementedError
1568-
1569-
result = left._constructor(new_data, index=left.index, copy=False)
1570-
# Pin columns instead of passing to constructor for compat with
1571-
# non-unique columns case
1572-
result.columns = left.columns
1573-
return result
1574-
1575-
15761555
def _combine_series_frame(self, other, func, fill_value=None, axis=None,
15771556
level=None, try_cast=True):
15781557
"""

0 commit comments

Comments
 (0)