2
2
import pytest
3
3
4
4
import pandas as pd
5
+ from pandas import DataFrame , Series , date_range
5
6
import pandas ._testing as tm
6
7
7
8
8
9
class TestDataFrameTruncate :
9
- def test_truncate (self , datetime_frame ):
10
+ def test_truncate (self , datetime_frame , frame_or_series ):
10
11
ts = datetime_frame [::3 ]
12
+ if frame_or_series is Series :
13
+ ts = ts .iloc [:, 0 ]
11
14
12
15
start , end = datetime_frame .index [3 ], datetime_frame .index [6 ]
13
16
@@ -16,34 +19,41 @@ def test_truncate(self, datetime_frame):
16
19
17
20
# neither specified
18
21
truncated = ts .truncate ()
19
- tm .assert_frame_equal (truncated , ts )
22
+ tm .assert_equal (truncated , ts )
20
23
21
24
# both specified
22
25
expected = ts [1 :3 ]
23
26
24
27
truncated = ts .truncate (start , end )
25
- tm .assert_frame_equal (truncated , expected )
28
+ tm .assert_equal (truncated , expected )
26
29
27
30
truncated = ts .truncate (start_missing , end_missing )
28
- tm .assert_frame_equal (truncated , expected )
31
+ tm .assert_equal (truncated , expected )
29
32
30
33
# start specified
31
34
expected = ts [1 :]
32
35
33
36
truncated = ts .truncate (before = start )
34
- tm .assert_frame_equal (truncated , expected )
37
+ tm .assert_equal (truncated , expected )
35
38
36
39
truncated = ts .truncate (before = start_missing )
37
- tm .assert_frame_equal (truncated , expected )
40
+ tm .assert_equal (truncated , expected )
38
41
39
42
# end specified
40
43
expected = ts [:3 ]
41
44
42
45
truncated = ts .truncate (after = end )
43
- tm .assert_frame_equal (truncated , expected )
46
+ tm .assert_equal (truncated , expected )
44
47
45
48
truncated = ts .truncate (after = end_missing )
46
- tm .assert_frame_equal (truncated , expected )
49
+ tm .assert_equal (truncated , expected )
50
+
51
+ # corner case, empty series/frame returned
52
+ truncated = ts .truncate (after = ts .index [0 ] - ts .index .freq )
53
+ assert len (truncated ) == 0
54
+
55
+ truncated = ts .truncate (before = ts .index [- 1 ] + ts .index .freq )
56
+ assert len (truncated ) == 0
47
57
48
58
msg = "Truncate: 2000-01-06 00:00:00 must be after 2000-02-04 00:00:00"
49
59
with pytest .raises (ValueError , match = msg ):
@@ -57,25 +67,35 @@ def test_truncate_copy(self, datetime_frame):
57
67
truncated .values [:] = 5.0
58
68
assert not (datetime_frame .values [5 :11 ] == 5 ).any ()
59
69
60
- def test_truncate_nonsortedindex (self ):
70
+ def test_truncate_nonsortedindex (self , frame_or_series ):
61
71
# GH#17935
62
72
63
- df = pd .DataFrame ({"A" : ["a" , "b" , "c" , "d" , "e" ]}, index = [5 , 3 , 2 , 9 , 0 ])
73
+ obj = DataFrame ({"A" : ["a" , "b" , "c" , "d" , "e" ]}, index = [5 , 3 , 2 , 9 , 0 ])
74
+ if frame_or_series is Series :
75
+ obj = obj ["A" ]
76
+
64
77
msg = "truncate requires a sorted index"
65
78
with pytest .raises (ValueError , match = msg ):
66
- df .truncate (before = 3 , after = 9 )
79
+ obj .truncate (before = 3 , after = 9 )
80
+
81
+ def test_sort_values_nonsortedindex (self ):
82
+ # TODO: belongs elsewhere?
67
83
68
- rng = pd . date_range ("2011-01-01" , "2012-01-01" , freq = "W" )
69
- ts = pd . DataFrame (
84
+ rng = date_range ("2011-01-01" , "2012-01-01" , freq = "W" )
85
+ ts = DataFrame (
70
86
{"A" : np .random .randn (len (rng )), "B" : np .random .randn (len (rng ))}, index = rng
71
87
)
88
+
72
89
msg = "truncate requires a sorted index"
73
90
with pytest .raises (ValueError , match = msg ):
74
91
ts .sort_values ("A" , ascending = False ).truncate (
75
92
before = "2011-11" , after = "2011-12"
76
93
)
77
94
78
- df = pd .DataFrame (
95
+ def test_truncate_nonsortedindex_axis1 (self ):
96
+ # GH#17935
97
+
98
+ df = DataFrame (
79
99
{
80
100
3 : np .random .randn (5 ),
81
101
20 : np .random .randn (5 ),
@@ -93,27 +113,34 @@ def test_truncate_nonsortedindex(self):
93
113
[(1 , 2 , [2 , 1 ]), (None , 2 , [2 , 1 , 0 ]), (1 , None , [3 , 2 , 1 ])],
94
114
)
95
115
@pytest .mark .parametrize ("klass" , [pd .Int64Index , pd .DatetimeIndex ])
96
- def test_truncate_decreasing_index (self , before , after , indices , klass ):
116
+ def test_truncate_decreasing_index (
117
+ self , before , after , indices , klass , frame_or_series
118
+ ):
97
119
# https://github.com/pandas-dev/pandas/issues/33756
98
120
idx = klass ([3 , 2 , 1 , 0 ])
99
121
if klass is pd .DatetimeIndex :
100
122
before = pd .Timestamp (before ) if before is not None else None
101
123
after = pd .Timestamp (after ) if after is not None else None
102
124
indices = [pd .Timestamp (i ) for i in indices ]
103
- values = pd . DataFrame (range (len (idx )), index = idx )
125
+ values = frame_or_series (range (len (idx )), index = idx )
104
126
result = values .truncate (before = before , after = after )
105
127
expected = values .loc [indices ]
106
- tm .assert_frame_equal (result , expected )
128
+ tm .assert_equal (result , expected )
107
129
108
- def test_truncate_multiindex (self ):
130
+ def test_truncate_multiindex (self , frame_or_series ):
109
131
# GH 34564
110
132
mi = pd .MultiIndex .from_product ([[1 , 2 , 3 , 4 ], ["A" , "B" ]], names = ["L1" , "L2" ])
111
- s1 = pd .DataFrame (range (mi .shape [0 ]), index = mi , columns = ["col" ])
133
+ s1 = DataFrame (range (mi .shape [0 ]), index = mi , columns = ["col" ])
134
+ if frame_or_series is Series :
135
+ s1 = s1 ["col" ]
136
+
112
137
result = s1 .truncate (before = 2 , after = 3 )
113
138
114
- df = pd . DataFrame .from_dict (
139
+ df = DataFrame .from_dict (
115
140
{"L1" : [2 , 2 , 3 , 3 ], "L2" : ["A" , "B" , "A" , "B" ], "col" : [2 , 3 , 4 , 5 ]}
116
141
)
117
142
expected = df .set_index (["L1" , "L2" ])
143
+ if frame_or_series is Series :
144
+ expected = expected ["col" ]
118
145
119
- tm .assert_frame_equal (result , expected )
146
+ tm .assert_equal (result , expected )
0 commit comments