1
1
import decimal
2
- import math
3
2
import operator
4
3
5
4
import numpy as np
@@ -70,54 +69,7 @@ def data_for_grouping():
70
69
return DecimalArray ([b , b , na , na , a , a , b , c ])
71
70
72
71
73
- class BaseDecimal :
74
- @classmethod
75
- def assert_series_equal (cls , left , right , * args , ** kwargs ):
76
- def convert (x ):
77
- # need to convert array([Decimal(NaN)], dtype='object') to np.NaN
78
- # because Series[object].isnan doesn't recognize decimal(NaN) as
79
- # NA.
80
- try :
81
- return math .isnan (x )
82
- except TypeError :
83
- return False
84
-
85
- if left .dtype == "object" :
86
- left_na = left .apply (convert )
87
- else :
88
- left_na = left .isna ()
89
- if right .dtype == "object" :
90
- right_na = right .apply (convert )
91
- else :
92
- right_na = right .isna ()
93
-
94
- tm .assert_series_equal (left_na , right_na )
95
- return tm .assert_series_equal (left [~ left_na ], right [~ right_na ], * args , ** kwargs )
96
-
97
- @classmethod
98
- def assert_frame_equal (cls , left , right , * args , ** kwargs ):
99
- # TODO(EA): select_dtypes
100
- tm .assert_index_equal (
101
- left .columns ,
102
- right .columns ,
103
- exact = kwargs .get ("check_column_type" , "equiv" ),
104
- check_names = kwargs .get ("check_names" , True ),
105
- check_exact = kwargs .get ("check_exact" , False ),
106
- check_categorical = kwargs .get ("check_categorical" , True ),
107
- obj = f"{ kwargs .get ('obj' , 'DataFrame' )} .columns" ,
108
- )
109
-
110
- decimals = (left .dtypes == "decimal" ).index
111
-
112
- for col in decimals :
113
- cls .assert_series_equal (left [col ], right [col ], * args , ** kwargs )
114
-
115
- left = left .drop (columns = decimals )
116
- right = right .drop (columns = decimals )
117
- tm .assert_frame_equal (left , right , * args , ** kwargs )
118
-
119
-
120
- class TestDtype (BaseDecimal , base .BaseDtypeTests ):
72
+ class TestDtype (base .BaseDtypeTests ):
121
73
def test_hashable (self , dtype ):
122
74
pass
123
75
@@ -129,27 +81,27 @@ def test_infer_dtype(self, data, data_missing, skipna):
129
81
assert infer_dtype (data_missing , skipna = skipna ) == "unknown-array"
130
82
131
83
132
- class TestInterface (BaseDecimal , base .BaseInterfaceTests ):
84
+ class TestInterface (base .BaseInterfaceTests ):
133
85
pass
134
86
135
87
136
- class TestConstructors (BaseDecimal , base .BaseConstructorsTests ):
88
+ class TestConstructors (base .BaseConstructorsTests ):
137
89
pass
138
90
139
91
140
- class TestReshaping (BaseDecimal , base .BaseReshapingTests ):
92
+ class TestReshaping (base .BaseReshapingTests ):
141
93
pass
142
94
143
95
144
- class TestGetitem (BaseDecimal , base .BaseGetitemTests ):
96
+ class TestGetitem (base .BaseGetitemTests ):
145
97
def test_take_na_value_other_decimal (self ):
146
98
arr = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("2.0" )])
147
99
result = arr .take ([0 , - 1 ], allow_fill = True , fill_value = decimal .Decimal ("-1.0" ))
148
100
expected = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("-1.0" )])
149
101
self .assert_extension_array_equal (result , expected )
150
102
151
103
152
- class TestMissing (BaseDecimal , base .BaseMissingTests ):
104
+ class TestMissing (base .BaseMissingTests ):
153
105
pass
154
106
155
107
@@ -175,7 +127,7 @@ class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
175
127
pass
176
128
177
129
178
- class TestMethods (BaseDecimal , base .BaseMethodsTests ):
130
+ class TestMethods (base .BaseMethodsTests ):
179
131
@pytest .mark .parametrize ("dropna" , [True , False ])
180
132
def test_value_counts (self , all_data , dropna , request ):
181
133
all_data = all_data [:10 ]
@@ -200,20 +152,20 @@ def test_value_counts_with_normalize(self, data):
200
152
return super ().test_value_counts_with_normalize (data )
201
153
202
154
203
- class TestCasting (BaseDecimal , base .BaseCastingTests ):
155
+ class TestCasting (base .BaseCastingTests ):
204
156
pass
205
157
206
158
207
- class TestGroupby (BaseDecimal , base .BaseGroupbyTests ):
159
+ class TestGroupby (base .BaseGroupbyTests ):
208
160
def test_groupby_agg_extension (self , data_for_grouping ):
209
161
super ().test_groupby_agg_extension (data_for_grouping )
210
162
211
163
212
- class TestSetitem (BaseDecimal , base .BaseSetitemTests ):
164
+ class TestSetitem (base .BaseSetitemTests ):
213
165
pass
214
166
215
167
216
- class TestPrinting (BaseDecimal , base .BasePrintingTests ):
168
+ class TestPrinting (base .BasePrintingTests ):
217
169
def test_series_repr (self , data ):
218
170
# Overriding this base test to explicitly test that
219
171
# the custom _formatter is used
@@ -282,7 +234,7 @@ def test_astype_dispatches(frame):
282
234
assert result .dtype .context .prec == ctx .prec
283
235
284
236
285
- class TestArithmeticOps (BaseDecimal , base .BaseArithmeticOpsTests ):
237
+ class TestArithmeticOps (base .BaseArithmeticOpsTests ):
286
238
def check_opname (self , s , op_name , other , exc = None ):
287
239
super ().check_opname (s , op_name , other , exc = None )
288
240
@@ -313,7 +265,7 @@ def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
313
265
super ()._check_divmod_op (s , op , other , exc = None )
314
266
315
267
316
- class TestComparisonOps (BaseDecimal , base .BaseComparisonOpsTests ):
268
+ class TestComparisonOps (base .BaseComparisonOpsTests ):
317
269
def test_compare_scalar (self , data , all_compare_operators ):
318
270
op_name = all_compare_operators
319
271
s = pd .Series (data )
0 commit comments