4
4
This is not a public API.
5
5
"""
6
6
import operator
7
- from typing import TYPE_CHECKING , Optional , Set
7
+ from typing import TYPE_CHECKING , Optional , Set , Type
8
8
9
9
import numpy as np
10
10
21
21
from pandas .core .ops .array_ops import (
22
22
arithmetic_op ,
23
23
comparison_op ,
24
- define_na_arithmetic_op ,
25
24
get_array_op ,
26
25
logical_op ,
27
26
)
28
27
from pandas .core .ops .array_ops import comp_method_OBJECT_ARRAY # noqa:F401
29
28
from pandas .core .ops .common import unpack_zerodim_and_defer
30
- from pandas .core .ops .dispatch import should_series_dispatch
31
29
from pandas .core .ops .docstrings import (
32
30
_arith_doc_FRAME ,
33
31
_flex_comp_doc_FRAME ,
@@ -154,7 +152,7 @@ def _maybe_match_name(a, b):
154
152
# -----------------------------------------------------------------------------
155
153
156
154
157
- def _get_frame_op_default_axis (name ) :
155
+ def _get_frame_op_default_axis (name : str ) -> Optional [ str ] :
158
156
"""
159
157
Only DataFrame cares about default_axis, specifically:
160
158
special methods have default_axis=None and flex methods
@@ -277,7 +275,11 @@ def dispatch_to_series(left, right, func, axis=None):
277
275
return type (left )(bm )
278
276
279
277
elif isinstance (right , ABCDataFrame ):
280
- assert right ._indexed_same (left )
278
+ assert left .index .equals (right .index )
279
+ assert left .columns .equals (right .columns )
280
+ # TODO: The previous assertion `assert right._indexed_same(left)`
281
+ # fails in cases with empty columns reached via
282
+ # _frame_arith_method_with_reindex
281
283
282
284
array_op = get_array_op (func )
283
285
bm = left ._mgr .operate_blockwise (right ._mgr , array_op )
@@ -345,6 +347,7 @@ def _arith_method_SERIES(cls, op, special):
345
347
Wrapper function for Series arithmetic operations, to avoid
346
348
code duplication.
347
349
"""
350
+ assert special # non-special uses _flex_method_SERIES
348
351
op_name = _get_op_name (op , special )
349
352
350
353
@unpack_zerodim_and_defer (op_name )
@@ -368,6 +371,7 @@ def _comp_method_SERIES(cls, op, special):
368
371
Wrapper function for Series arithmetic operations, to avoid
369
372
code duplication.
370
373
"""
374
+ assert special # non-special uses _flex_method_SERIES
371
375
op_name = _get_op_name (op , special )
372
376
373
377
@unpack_zerodim_and_defer (op_name )
@@ -394,6 +398,7 @@ def _bool_method_SERIES(cls, op, special):
394
398
Wrapper function for Series arithmetic operations, to avoid
395
399
code duplication.
396
400
"""
401
+ assert special # non-special uses _flex_method_SERIES
397
402
op_name = _get_op_name (op , special )
398
403
399
404
@unpack_zerodim_and_defer (op_name )
@@ -412,6 +417,7 @@ def wrapper(self, other):
412
417
413
418
414
419
def _flex_method_SERIES (cls , op , special ):
420
+ assert not special # "special" also means "not flex"
415
421
name = _get_op_name (op , special )
416
422
doc = _make_flex_doc (name , "series" )
417
423
@@ -574,7 +580,7 @@ def to_series(right):
574
580
575
581
576
582
def _should_reindex_frame_op (
577
- left : "DataFrame" , right , op , axis , default_axis : int , fill_value , level
583
+ left : "DataFrame" , right , op , axis , default_axis , fill_value , level
578
584
) -> bool :
579
585
"""
580
586
Check if this is an operation between DataFrames that will need to reindex.
@@ -629,11 +635,12 @@ def _frame_arith_method_with_reindex(
629
635
return result .reindex (join_columns , axis = 1 )
630
636
631
637
632
- def _arith_method_FRAME (cls , op , special ):
638
+ def _arith_method_FRAME (cls : Type ["DataFrame" ], op , special : bool ):
639
+ # This is the only function where `special` can be either True or False
633
640
op_name = _get_op_name (op , special )
634
641
default_axis = _get_frame_op_default_axis (op_name )
635
642
636
- na_op = define_na_arithmetic_op (op )
643
+ na_op = get_array_op (op )
637
644
is_logical = op .__name__ .strip ("_" ).lstrip ("_" ) in ["and" , "or" , "xor" ]
638
645
639
646
if op_name in _op_descriptions :
@@ -650,18 +657,19 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
650
657
):
651
658
return _frame_arith_method_with_reindex (self , other , op )
652
659
660
+ # TODO: why are we passing flex=True instead of flex=not special?
661
+ # 15 tests fail if we pass flex=not special instead
653
662
self , other = _align_method_FRAME (self , other , axis , flex = True , level = level )
654
663
655
664
if isinstance (other , ABCDataFrame ):
656
665
# Another DataFrame
657
- pass_op = op if should_series_dispatch (self , other , op ) else na_op
658
- pass_op = pass_op if not is_logical else op
659
-
660
- new_data = self ._combine_frame (other , pass_op , fill_value )
666
+ new_data = self ._combine_frame (other , na_op , fill_value )
661
667
662
668
elif isinstance (other , ABCSeries ):
663
669
# For these values of `axis`, we end up dispatching to Series op,
664
670
# so do not want the masked op.
671
+ # TODO: the above comment is no longer accurate since we now
672
+ # operate blockwise if other._values is an ndarray
665
673
pass_op = op if axis in [0 , "columns" , None ] else na_op
666
674
pass_op = pass_op if not is_logical else op
667
675
@@ -684,9 +692,11 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
684
692
return f
685
693
686
694
687
- def _flex_comp_method_FRAME (cls , op , special ):
695
+ def _flex_comp_method_FRAME (cls : Type ["DataFrame" ], op , special : bool ):
696
+ assert not special # "special" also means "not flex"
688
697
op_name = _get_op_name (op , special )
689
698
default_axis = _get_frame_op_default_axis (op_name )
699
+ assert default_axis == "columns" , default_axis # because we are not "special"
690
700
691
701
doc = _flex_comp_doc_FRAME .format (
692
702
op_name = op_name , desc = _op_descriptions [op_name ]["desc" ]
@@ -715,7 +725,8 @@ def f(self, other, axis=default_axis, level=None):
715
725
return f
716
726
717
727
718
- def _comp_method_FRAME (cls , op , special ):
728
+ def _comp_method_FRAME (cls : Type ["DataFrame" ], op , special : bool ):
729
+ assert special # "special" also means "not flex"
719
730
op_name = _get_op_name (op , special )
720
731
721
732
@Appender (f"Wrapper for comparison method { op_name } " )
0 commit comments