@@ -35,68 +35,59 @@ def na_value():
35
35
return decimal .Decimal ("NaN" )
36
36
37
37
38
- class TestDtype (base .BaseDtypeTests ):
39
- pass
38
+ class BaseDecimal (object ):
39
+ @staticmethod
40
+ def assert_series_equal (left , right , * args , ** kwargs ):
41
+ # tm.assert_series_equal doesn't handle Decimal('NaN').
42
+ # We will ensure that the NA values match, and then
43
+ # drop those values before moving on.
40
44
45
+ left_na = left .isna ()
46
+ right_na = right .isna ()
41
47
42
- class TestInterface (base .BaseInterfaceTests ):
43
- pass
48
+ tm .assert_series_equal (left_na , right_na )
49
+ tm .assert_series_equal (left [~ left_na ], right [~ right_na ],
50
+ * args , ** kwargs )
44
51
52
+ @staticmethod
53
+ def assert_frame_equal (left , right , * args , ** kwargs ):
54
+ # TODO(EA): select_dtypes
55
+ decimals = (left .dtypes == 'decimal' ).index
45
56
46
- class TestConstructors (base .BaseConstructorsTests ):
47
- pass
57
+ for col in decimals :
58
+ BaseDecimal .assert_series_equal (left [col ], right [col ],
59
+ * args , ** kwargs )
48
60
61
+ left = left .drop (columns = decimals )
62
+ right = right .drop (columns = decimals )
63
+ tm .assert_frame_equal (left , right , * args , ** kwargs )
49
64
50
- class TestReshaping (base .BaseReshapingTests ):
51
65
52
- def test_align (self , data , na_value ):
53
- # Have to override since assert_series_equal doesn't
54
- # compare Decimal(NaN) properly.
55
- a = data [:3 ]
56
- b = data [2 :5 ]
57
- r1 , r2 = pd .Series (a ).align (pd .Series (b , index = [1 , 2 , 3 ]))
66
+ class TestDtype (BaseDecimal , base .BaseDtypeTests ):
67
+ pass
58
68
59
- # NaN handling
60
- e1 = pd .Series (type (data )(list (a ) + [na_value ]))
61
- e2 = pd .Series (type (data )([na_value ] + list (b )))
62
- tm .assert_series_equal (r1 .iloc [:3 ], e1 .iloc [:3 ])
63
- assert r1 [3 ].is_nan ()
64
- assert e1 [3 ].is_nan ()
65
69
66
- tm .assert_series_equal (r2 .iloc [1 :], e2 .iloc [1 :])
67
- assert r2 [0 ].is_nan ()
68
- assert e2 [0 ].is_nan ()
70
+ class TestInterface (BaseDecimal , base .BaseInterfaceTests ):
71
+ pass
69
72
70
- def test_align_frame (self , data , na_value ):
71
- # Override for Decimal(NaN) comparison
72
- a = data [:3 ]
73
- b = data [2 :5 ]
74
- r1 , r2 = pd .DataFrame ({'A' : a }).align (
75
- pd .DataFrame ({'A' : b }, index = [1 , 2 , 3 ])
76
- )
77
73
78
- # Assumes that the ctor can take a list of scalars of the type
79
- e1 = pd .DataFrame ({'A' : type (data )(list (a ) + [na_value ])})
80
- e2 = pd .DataFrame ({'A' : type (data )([na_value ] + list (b ))})
74
+ class TestConstructors (BaseDecimal , base .BaseConstructorsTests ):
75
+ pass
81
76
82
- tm .assert_frame_equal (r1 .iloc [:3 ], e1 .iloc [:3 ])
83
- assert r1 .loc [3 , 'A' ].is_nan ()
84
- assert e1 .loc [3 , 'A' ].is_nan ()
85
77
86
- tm .assert_frame_equal (r2 .iloc [1 :], e2 .iloc [1 :])
87
- assert r2 .loc [0 , 'A' ].is_nan ()
88
- assert e2 .loc [0 , 'A' ].is_nan ()
78
+ class TestReshaping (BaseDecimal , base .BaseReshapingTests ):
79
+ pass
89
80
90
81
91
- class TestGetitem (base .BaseGetitemTests ):
82
+ class TestGetitem (BaseDecimal , base .BaseGetitemTests ):
92
83
pass
93
84
94
85
95
- class TestMissing (base .BaseMissingTests ):
86
+ class TestMissing (BaseDecimal , base .BaseMissingTests ):
96
87
pass
97
88
98
89
99
- class TestMethods (base .BaseMethodsTests ):
90
+ class TestMethods (BaseDecimal , base .BaseMethodsTests ):
100
91
@pytest .mark .parametrize ('dropna' , [True , False ])
101
92
@pytest .mark .xfail (reason = "value_counts not implemented yet." )
102
93
def test_value_counts (self , all_data , dropna ):
@@ -112,7 +103,7 @@ def test_value_counts(self, all_data, dropna):
112
103
tm .assert_series_equal (result , expected )
113
104
114
105
115
- class TestCasting (base .BaseCastingTests ):
106
+ class TestCasting (BaseDecimal , base .BaseCastingTests ):
116
107
pass
117
108
118
109
0 commit comments