@@ -281,42 +281,33 @@ def test_compare_scalar(self, data, comparison_op):
281
281
282
282
@pytest .mark .parametrize ("na_action" , [None , "ignore" ])
283
283
def test_map (self , data_missing , na_action ):
284
- if data_missing .dtype .kind in "mM" :
285
- mapped = data_missing .map (lambda x : x , na_action = na_action )
286
- result = pd .Series (mapped )
287
- expected = pd .Series (data_missing .to_numpy ())
284
+ result = data_missing .map (lambda x : x , na_action = na_action )
288
285
289
- orig_dtype = expected .dtype
286
+ if data_missing .dtype == "float32[pyarrow]" :
287
+ expected = data_missing .to_numpy (dtype = "float64" , na_value = np .nan )
288
+ tm .assert_numpy_array_equal (result , expected )
290
289
291
- if result .dtype == "float64" and (
292
- is_datetime64_any_dtype (orig_dtype )
293
- or is_timedelta64_dtype (orig_dtype )
294
- or isinstance (orig_dtype , pd .DatetimeTZDtype )
295
- ):
296
- result = result .astype (orig_dtype )
290
+ elif data_missing .dtype .kind in "mM" :
291
+ expected = pd .Series (data_missing .to_numpy ())
297
292
293
+ orig_dtype = expected .dtype
298
294
if isinstance (orig_dtype , pd .DatetimeTZDtype ):
299
295
pass
300
296
elif is_datetime64_any_dtype (orig_dtype ):
301
297
result = result .astype ("datetime64[ns]" ).astype ("int64" )
302
298
expected = expected .astype ("datetime64[ns]" ).astype ("int64" )
303
- result = pd .Series (result )
304
- expected = pd .Series (expected )
305
299
elif is_timedelta64_dtype (orig_dtype ):
306
300
result = result .astype ("timedelta64[ns]" )
307
301
expected = expected .astype ("timedelta64[ns]" )
308
302
303
+ result = pd .Series (result )
304
+ expected = pd .Series (expected )
309
305
tm .assert_series_equal (
310
306
result , expected , check_dtype = False , check_exact = False
311
307
)
312
308
313
309
else :
314
- result = data_missing .map (lambda x : x , na_action = na_action )
315
- if data_missing .dtype == "float32[pyarrow]" :
316
- # map roundtrips through objects, which converts to float64
317
- expected = data_missing .to_numpy (dtype = "float64" , na_value = np .nan )
318
- else :
319
- expected = data_missing .to_numpy ()
310
+ expected = data_missing .to_numpy ()
320
311
tm .assert_numpy_array_equal (result , expected )
321
312
322
313
def test_astype_str (self , data , request , using_infer_string ):
0 commit comments