@@ -160,28 +160,22 @@ def test_where_unsafe_float(float_dtype):
160
160
assert_series_equal (s , expected )
161
161
162
162
163
- @pytest .mark .parametrize ("dtype" , [np .int64 , np .float64 ])
164
- def test_where_unsafe_upcast (dtype ):
165
- s = Series (np .arange (10 ), dtype = dtype )
166
- values = [2.5 , 3.5 , 4.5 , 5.5 , 6.5 ]
167
-
168
- mask = s < 5
169
- expected = Series (values + lrange (5 , 10 ), dtype = "float64" )
170
-
171
- s [mask ] = values
172
- assert_series_equal (s , expected )
173
-
174
-
175
- @pytest .mark .parametrize ("dtype" , [
176
- np .int8 , np .int16 , np .int32 , np .float32
163
+ @pytest .mark .parametrize ("dtype,expected_dtype" , [
164
+ (np .int8 , np .float64 ),
165
+ (np .int16 , np .float64 ),
166
+ (np .int32 , np .float64 ),
167
+ (np .int64 , np .float64 ),
168
+ (np .float32 , np .float32 ),
169
+ (np .float64 , np .float64 )
177
170
])
178
- def test_where_upcast (dtype ):
171
+ def test_where_unsafe_upcast (dtype , expected_dtype ):
179
172
# see gh-9743
180
173
s = Series (np .arange (10 ), dtype = dtype )
181
- mask = s < 5
182
-
183
174
values = [2.5 , 3.5 , 4.5 , 5.5 , 6.5 ]
175
+ mask = s < 5
176
+ expected = Series (values + lrange (5 , 10 ), dtype = expected_dtype )
184
177
s [mask ] = values
178
+ assert_series_equal (s , expected )
185
179
186
180
187
181
def test_where_unsafe ():
0 commit comments