1
1
from __future__ import annotations
2
2
3
- from typing import cast
3
+ from typing import (
4
+ Literal ,
5
+ cast ,
6
+ )
4
7
import warnings
5
8
6
9
import numpy as np
10
13
no_default ,
11
14
)
12
15
from pandas ._libs .missing import is_matching_na
16
+ from pandas ._libs .sparse import SparseIndex
13
17
import pandas ._libs .testing as _testing
14
18
from pandas .util ._exceptions import find_stack_level
15
19
61
65
def assert_almost_equal (
62
66
left ,
63
67
right ,
64
- check_dtype : bool | str = "equiv" ,
68
+ check_dtype : bool | Literal [ "equiv" ] = "equiv" ,
65
69
check_less_precise : bool | int | NoDefault = no_default ,
66
70
rtol : float = 1.0e-5 ,
67
71
atol : float = 1.0e-8 ,
@@ -164,9 +168,8 @@ def assert_almost_equal(
164
168
assert_class_equal (left , right , obj = obj )
165
169
166
170
# if we have "equiv", this becomes True
167
- check_dtype = bool (check_dtype )
168
171
_testing .assert_almost_equal (
169
- left , right , check_dtype = check_dtype , rtol = rtol , atol = atol , ** kwargs
172
+ left , right , check_dtype = bool ( check_dtype ) , rtol = rtol , atol = atol , ** kwargs
170
173
)
171
174
172
175
@@ -676,7 +679,7 @@ def assert_numpy_array_equal(
676
679
left ,
677
680
right ,
678
681
strict_nan = False ,
679
- check_dtype = True ,
682
+ check_dtype : bool | Literal [ "equiv" ] = True ,
680
683
err_msg = None ,
681
684
check_same = None ,
682
685
obj = "numpy array" ,
@@ -755,7 +758,7 @@ def _raise(left, right, err_msg):
755
758
def assert_extension_array_equal (
756
759
left ,
757
760
right ,
758
- check_dtype = True ,
761
+ check_dtype : bool | Literal [ "equiv" ] = True ,
759
762
index_values = None ,
760
763
check_less_precise = no_default ,
761
764
check_exact = False ,
@@ -848,7 +851,7 @@ def assert_extension_array_equal(
848
851
_testing .assert_almost_equal (
849
852
left_valid ,
850
853
right_valid ,
851
- check_dtype = check_dtype ,
854
+ check_dtype = bool ( check_dtype ) ,
852
855
rtol = rtol ,
853
856
atol = atol ,
854
857
obj = "ExtensionArray" ,
@@ -860,7 +863,7 @@ def assert_extension_array_equal(
860
863
def assert_series_equal (
861
864
left ,
862
865
right ,
863
- check_dtype = True ,
866
+ check_dtype : bool | Literal [ "equiv" ] = True ,
864
867
check_index_type = "equiv" ,
865
868
check_series_type = True ,
866
869
check_less_precise = no_default ,
@@ -1054,7 +1057,7 @@ def assert_series_equal(
1054
1057
right ._values ,
1055
1058
rtol = rtol ,
1056
1059
atol = atol ,
1057
- check_dtype = check_dtype ,
1060
+ check_dtype = bool ( check_dtype ) ,
1058
1061
obj = str (obj ),
1059
1062
index_values = np .asarray (left .index ),
1060
1063
)
@@ -1090,7 +1093,7 @@ def assert_series_equal(
1090
1093
right ._values ,
1091
1094
rtol = rtol ,
1092
1095
atol = atol ,
1093
- check_dtype = check_dtype ,
1096
+ check_dtype = bool ( check_dtype ) ,
1094
1097
obj = str (obj ),
1095
1098
index_values = np .asarray (left .index ),
1096
1099
)
@@ -1115,7 +1118,7 @@ def assert_series_equal(
1115
1118
def assert_frame_equal (
1116
1119
left ,
1117
1120
right ,
1118
- check_dtype = True ,
1121
+ check_dtype : bool | Literal [ "equiv" ] = True ,
1119
1122
check_index_type = "equiv" ,
1120
1123
check_column_type = "equiv" ,
1121
1124
check_frame_type = True ,
@@ -1393,8 +1396,8 @@ def assert_sp_array_equal(left, right):
1393
1396
assert_numpy_array_equal (left .sp_values , right .sp_values )
1394
1397
1395
1398
# SparseIndex comparison
1396
- assert isinstance (left .sp_index , pd . _libs . sparse . SparseIndex )
1397
- assert isinstance (right .sp_index , pd . _libs . sparse . SparseIndex )
1399
+ assert isinstance (left .sp_index , SparseIndex )
1400
+ assert isinstance (right .sp_index , SparseIndex )
1398
1401
1399
1402
left_index = left .sp_index
1400
1403
right_index = right .sp_index
0 commit comments