@@ -1088,35 +1088,43 @@ def dispatch_to_extension_op(op, left, right):
1088
1088
# we need to listify to avoid ndarray, or non-same-type extension array
1089
1089
# dispatching
1090
1090
1091
+ new_type , left_type , right_type = None , None , None
1091
1092
if is_extension_array_dtype (left ):
1092
-
1093
+ left_type = left . dtype . type
1093
1094
new_left = left .values
1094
1095
if isinstance (right , np .ndarray ):
1095
1096
1096
1097
# handle numpy scalars, this is a PITA
1097
1098
# TODO(jreback)
1098
1099
new_right = lib .item_from_zerodim (right )
1100
+ right_type = new_right .dtype
1099
1101
if is_scalar (new_right ):
1100
1102
new_right = [new_right ]
1101
1103
new_right = list (new_right )
1102
1104
elif is_extension_array_dtype (right ) and type (left ) != type (right ):
1105
+ right_type = new_right .dtype .type
1103
1106
new_right = list (new_right )
1104
1107
else :
1105
1108
new_right = right
1106
-
1109
+ right_type = type ( right )
1107
1110
else :
1108
1111
1109
1112
new_left = list (left .values )
1110
1113
new_right = right
1111
1114
1112
1115
res_values = op (new_left , new_right )
1113
1116
res_name = get_op_result_name (left , right )
1114
-
1117
+ if right_type and left_type :
1118
+ new_type = find_common_type ([right_type , left_type ])
1115
1119
if op .__name__ == 'divmod' :
1116
1120
return _construct_divmod_result (
1117
1121
left , res_values , left .index , res_name )
1118
1122
1119
- return _construct_result (left , res_values , left .index , res_name )
1123
+ result = _construct_result (left , res_values , left .index , res_name )
1124
+ if result .dtype == "object" :
1125
+ result = _construct_result (left , res_values , left .index , res_name ,
1126
+ new_type )
1127
+ return result
1120
1128
1121
1129
1122
1130
def _arith_method_SERIES (cls , op , special ):
@@ -1143,7 +1151,6 @@ def na_op(x, y):
1143
1151
result [mask ] = op (x [mask ], com .values_from_object (y [mask ]))
1144
1152
else :
1145
1153
assert isinstance (x , np .ndarray )
1146
- assert is_scalar (y )
1147
1154
result = np .empty (len (x ), dtype = x .dtype )
1148
1155
mask = notna (x )
1149
1156
result [mask ] = op (x [mask ], y )
@@ -1190,7 +1197,6 @@ def wrapper(left, right):
1190
1197
1191
1198
elif (is_extension_array_dtype (left ) or
1192
1199
is_extension_array_dtype (right )):
1193
- # TODO: should this include `not is_scalar(right)`?
1194
1200
return dispatch_to_extension_op (op , left , right )
1195
1201
1196
1202
elif is_datetime64_dtype (left ) or is_datetime64tz_dtype (left ):
@@ -1280,11 +1286,13 @@ def na_op(x, y):
1280
1286
# should have guarantess on what x, y can be type-wise
1281
1287
# Extension Dtypes are not called here
1282
1288
1283
- # Checking that cases that were once handled here are no longer
1284
- # reachable.
1285
- assert not (is_categorical_dtype (y ) and not is_scalar (y ))
1289
+ # dispatch to the categorical if we have a categorical
1290
+ # in either operand
1291
+ if is_categorical_dtype (y ) and not is_scalar (y ):
1292
+ # The `not is_scalar(y)` check excludes the string "category"
1293
+ return op (y , x )
1286
1294
1287
- if is_object_dtype (x .dtype ):
1295
+ elif is_object_dtype (x .dtype ):
1288
1296
result = _comp_method_OBJECT_ARRAY (op , x , y )
1289
1297
1290
1298
elif is_datetimelike_v_numeric (x , y ):
@@ -1342,7 +1350,7 @@ def wrapper(self, other, axis=None):
1342
1350
return self ._constructor (res_values , index = self .index ,
1343
1351
name = res_name )
1344
1352
1345
- elif is_datetime64_dtype (self ) or is_datetime64tz_dtype (self ):
1353
+ if is_datetime64_dtype (self ) or is_datetime64tz_dtype (self ):
1346
1354
# Dispatch to DatetimeIndex to ensure identical
1347
1355
# Series/Index behavior
1348
1356
if (isinstance (other , datetime .date ) and
@@ -1384,9 +1392,8 @@ def wrapper(self, other, axis=None):
1384
1392
name = res_name )
1385
1393
1386
1394
elif (is_extension_array_dtype (self ) or
1387
- (is_extension_array_dtype (other ) and not is_scalar (other ))):
1388
- # Note: the `not is_scalar(other)` condition rules out
1389
- # e.g. other == "category"
1395
+ (is_extension_array_dtype (other ) and
1396
+ not is_scalar (other ))):
1390
1397
return dispatch_to_extension_op (op , self , other )
1391
1398
1392
1399
elif isinstance (other , ABCSeries ):
@@ -1409,6 +1416,13 @@ def wrapper(self, other, axis=None):
1409
1416
# is not.
1410
1417
return result .__finalize__ (self ).rename (res_name )
1411
1418
1419
+ elif isinstance (other , pd .Categorical ):
1420
+ # ordering of checks matters; by this point we know
1421
+ # that not is_categorical_dtype(self)
1422
+ res_values = op (self .values , other )
1423
+ return self ._constructor (res_values , index = self .index ,
1424
+ name = res_name )
1425
+
1412
1426
elif is_scalar (other ) and isna (other ):
1413
1427
# numpy does not like comparisons vs None
1414
1428
if op is operator .ne :
@@ -1538,41 +1552,6 @@ def flex_wrapper(self, other, level=None, fill_value=None, axis=0):
1538
1552
# -----------------------------------------------------------------------------
1539
1553
# DataFrame
1540
1554
1541
- def dispatch_to_series (left , right , func ):
1542
- """
1543
- Evaluate the frame operation func(left, right) by evaluating
1544
- column-by-column, dispatching to the Series implementation.
1545
-
1546
- Parameters
1547
- ----------
1548
- left : DataFrame
1549
- right : scalar or DataFrame
1550
- func : arithmetic or comparison operator
1551
-
1552
- Returns
1553
- -------
1554
- DataFrame
1555
- """
1556
- # Note: we use iloc to access columns for compat with cases
1557
- # with non-unique columns.
1558
- if lib .is_scalar (right ):
1559
- new_data = {i : func (left .iloc [:, i ], right )
1560
- for i in range (len (left .columns ))}
1561
- elif isinstance (right , ABCDataFrame ):
1562
- assert right ._indexed_same (left )
1563
- new_data = {i : func (left .iloc [:, i ], right .iloc [:, i ])
1564
- for i in range (len (left .columns ))}
1565
- else :
1566
- # Remaining cases have less-obvious dispatch rules
1567
- raise NotImplementedError
1568
-
1569
- result = left ._constructor (new_data , index = left .index , copy = False )
1570
- # Pin columns instead of passing to constructor for compat with
1571
- # non-unique columns case
1572
- result .columns = left .columns
1573
- return result
1574
-
1575
-
1576
1555
def _combine_series_frame (self , other , func , fill_value = None , axis = None ,
1577
1556
level = None , try_cast = True ):
1578
1557
"""
0 commit comments