@@ -980,24 +980,97 @@ def test_comparison_invalid(self):
980
980
self .assertRaises (TypeError , lambda : x <= y )
981
981
982
982
def test_more_na_comparisons (self ):
983
- left = Series (['a' , np .nan , 'c' ])
984
- right = Series (['a' , np .nan , 'd' ])
983
+ for dtype in [None , object ]:
984
+ left = Series (['a' , np .nan , 'c' ], dtype = dtype )
985
+ right = Series (['a' , np .nan , 'd' ], dtype = dtype )
985
986
986
- result = left == right
987
- expected = Series ([True , False , False ])
988
- assert_series_equal (result , expected )
987
+ result = left == right
988
+ expected = Series ([True , False , False ])
989
+ assert_series_equal (result , expected )
989
990
990
- result = left != right
991
- expected = Series ([False , True , True ])
992
- assert_series_equal (result , expected )
991
+ result = left != right
992
+ expected = Series ([False , True , True ])
993
+ assert_series_equal (result , expected )
993
994
994
- result = left == np .nan
995
- expected = Series ([False , False , False ])
996
- assert_series_equal (result , expected )
995
+ result = left == np .nan
996
+ expected = Series ([False , False , False ])
997
+ assert_series_equal (result , expected )
997
998
998
- result = left != np .nan
999
- expected = Series ([True , True , True ])
1000
- assert_series_equal (result , expected )
999
+ result = left != np .nan
1000
+ expected = Series ([True , True , True ])
1001
+ assert_series_equal (result , expected )
1002
+
1003
+ def test_nat_comparisons (self ):
1004
+ data = [([pd .Timestamp ('2011-01-01' ), pd .NaT ,
1005
+ pd .Timestamp ('2011-01-03' )],
1006
+ [pd .NaT , pd .NaT , pd .Timestamp ('2011-01-03' )]),
1007
+
1008
+ ([pd .Timedelta ('1 days' ), pd .NaT ,
1009
+ pd .Timedelta ('3 days' )],
1010
+ [pd .NaT , pd .NaT , pd .Timedelta ('3 days' )]),
1011
+
1012
+ ([pd .Period ('2011-01' , freq = 'M' ), pd .NaT ,
1013
+ pd .Period ('2011-03' , freq = 'M' )],
1014
+ [pd .NaT , pd .NaT , pd .Period ('2011-03' , freq = 'M' )])]
1015
+
1016
+ # add lhs / rhs switched data
1017
+ data = data + [(r , l ) for l , r in data ]
1018
+
1019
+ for l , r in data :
1020
+ for dtype in [None , object ]:
1021
+ left = Series (l , dtype = dtype )
1022
+
1023
+ # Series, Index
1024
+ for right in [Series (r , dtype = dtype ), Index (r , dtype = dtype )]:
1025
+ expected = Series ([False , False , True ])
1026
+ assert_series_equal (left == right , expected )
1027
+
1028
+ expected = Series ([True , True , False ])
1029
+ assert_series_equal (left != right , expected )
1030
+
1031
+ expected = Series ([False , False , False ])
1032
+ assert_series_equal (left < right , expected )
1033
+
1034
+ expected = Series ([False , False , False ])
1035
+ assert_series_equal (left > right , expected )
1036
+
1037
+ expected = Series ([False , False , True ])
1038
+ assert_series_equal (left >= right , expected )
1039
+
1040
+ expected = Series ([False , False , True ])
1041
+ assert_series_equal (left <= right , expected )
1042
+
1043
+ def test_nat_comparisons_scalar (self ):
1044
+ data = [[pd .Timestamp ('2011-01-01' ), pd .NaT ,
1045
+ pd .Timestamp ('2011-01-03' )],
1046
+
1047
+ [pd .Timedelta ('1 days' ), pd .NaT , pd .Timedelta ('3 days' )],
1048
+
1049
+ [pd .Period ('2011-01' , freq = 'M' ), pd .NaT ,
1050
+ pd .Period ('2011-03' , freq = 'M' )]]
1051
+
1052
+ for l in data :
1053
+ for dtype in [None , object ]:
1054
+ left = Series (l , dtype = dtype )
1055
+
1056
+ expected = Series ([False , False , False ])
1057
+ assert_series_equal (left == pd .NaT , expected )
1058
+ assert_series_equal (pd .NaT == left , expected )
1059
+
1060
+ expected = Series ([True , True , True ])
1061
+ assert_series_equal (left != pd .NaT , expected )
1062
+ assert_series_equal (pd .NaT != left , expected )
1063
+
1064
+ expected = Series ([False , False , False ])
1065
+ assert_series_equal (left < pd .NaT , expected )
1066
+ assert_series_equal (pd .NaT > left , expected )
1067
+ assert_series_equal (left <= pd .NaT , expected )
1068
+ assert_series_equal (pd .NaT >= left , expected )
1069
+
1070
+ assert_series_equal (left > pd .NaT , expected )
1071
+ assert_series_equal (pd .NaT < left , expected )
1072
+ assert_series_equal (left >= pd .NaT , expected )
1073
+ assert_series_equal (pd .NaT <= left , expected )
1001
1074
1002
1075
def test_comparison_different_length (self ):
1003
1076
a = Series (['a' , 'b' , 'c' ])
0 commit comments