@@ -442,8 +442,9 @@ def test_astype_str(self):
442
442
expected = DataFrame (['1.12345678901' ])
443
443
assert_frame_equal (result , expected )
444
444
445
- def test_astype_dict (self ):
446
- # GH7271
445
+ @pytest .mark .parametrize ("dtype_class" , [dict , Series ])
446
+ def test_astype_dict_like (self , dtype_class ):
447
+ # GH7271 & GH16717
447
448
a = Series (date_range ('2010-01-04' , periods = 5 ))
448
449
b = Series (range (5 ))
449
450
c = Series ([0.0 , 0.2 , 0.4 , 0.6 , 0.8 ])
@@ -452,7 +453,8 @@ def test_astype_dict(self):
452
453
original = df .copy (deep = True )
453
454
454
455
# change type of a subset of columns
455
- result = df .astype ({'b' : 'str' , 'd' : 'float32' })
456
+ dt1 = dtype_class ({'b' : 'str' , 'd' : 'float32' })
457
+ result = df .astype (dt1 )
456
458
expected = DataFrame ({
457
459
'a' : a ,
458
460
'b' : Series (['0' , '1' , '2' , '3' , '4' ]),
@@ -461,7 +463,8 @@ def test_astype_dict(self):
461
463
assert_frame_equal (result , expected )
462
464
assert_frame_equal (df , original )
463
465
464
- result = df .astype ({'b' : np .float32 , 'c' : 'float32' , 'd' : np .float64 })
466
+ dt2 = dtype_class ({'b' : np .float32 , 'c' : 'float32' , 'd' : np .float64 })
467
+ result = df .astype (dt2 )
465
468
expected = DataFrame ({
466
469
'a' : a ,
467
470
'b' : Series ([0.0 , 1.0 , 2.0 , 3.0 , 4.0 ], dtype = 'float32' ),
@@ -471,19 +474,31 @@ def test_astype_dict(self):
471
474
assert_frame_equal (df , original )
472
475
473
476
# change all columns
474
- assert_frame_equal (df .astype ({'a' : str , 'b' : str , 'c' : str , 'd' : str }),
477
+ dt3 = dtype_class ({'a' : str , 'b' : str , 'c' : str , 'd' : str })
478
+ assert_frame_equal (df .astype (dt3 ),
475
479
df .astype (str ))
476
480
assert_frame_equal (df , original )
477
481
478
482
# error should be raised when using something other than column labels
479
483
# in the keys of the dtype dict
480
- pytest .raises (KeyError , df .astype , {'b' : str , 2 : str })
481
- pytest .raises (KeyError , df .astype , {'e' : str })
484
+ dt4 = dtype_class ({'b' : str , 2 : str })
485
+ dt5 = dtype_class ({'e' : str })
486
+ pytest .raises (KeyError , df .astype , dt4 )
487
+ pytest .raises (KeyError , df .astype , dt5 )
482
488
assert_frame_equal (df , original )
483
489
484
490
# if the dtypes provided are the same as the original dtypes, the
485
491
# resulting DataFrame should be the same as the original DataFrame
486
- equiv = df .astype ({col : df [col ].dtype for col in df .columns })
492
+ dt6 = dtype_class ({col : df [col ].dtype for col in df .columns })
493
+ equiv = df .astype (dt6 )
494
+ assert_frame_equal (df , equiv )
495
+ assert_frame_equal (df , original )
496
+
497
+ # GH 16717
498
+ # if dtypes provided is empty, the resulting DataFrame
499
+ # should be the same as the original DataFrame
500
+ dt7 = dtype_class ({})
501
+ result = df .astype (dt7 )
487
502
assert_frame_equal (df , equiv )
488
503
assert_frame_equal (df , original )
489
504
0 commit comments