@@ -126,3 +126,57 @@ def test_argsort_1d(dtype):
126
126
127
127
s1_idx = dpt .argsort (inp , descending = True )
128
128
assert dpt .all (inp [s1_idx [:- 1 ]] >= inp [s1_idx [1 :]])
129
+
130
+
131
+ def test_sort_validation ():
132
+ with pytest .raises (TypeError ):
133
+ dpt .sort (dict ())
134
+
135
+
136
+ def test_argsort_validation ():
137
+ with pytest .raises (TypeError ):
138
+ dpt .argsort (dict ())
139
+
140
+
141
+ def test_sort_axis0 ():
142
+ get_queue_or_skip ()
143
+
144
+ n , m = 200 , 30
145
+ xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
146
+ x = dpt .reshape (xf , (n , m ))
147
+ s = dpt .sort (x , axis = 0 )
148
+
149
+ assert dpt .all (s [:- 1 , :] <= s [1 :, :])
150
+
151
+
152
+ def test_argsort_axis0 ():
153
+ get_queue_or_skip ()
154
+
155
+ n , m = 200 , 30
156
+ xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
157
+ x = dpt .reshape (xf , (n , m ))
158
+ idx = dpt .argsort (x , axis = 0 )
159
+
160
+ s = x [idx , dpt .arange (m )[dpt .newaxis , :]]
161
+
162
+ assert dpt .all (s [:- 1 , :] <= s [1 :, :])
163
+
164
+
165
+ def test_sort_strided ():
166
+ get_queue_or_skip ()
167
+
168
+ x_orig = dpt .arange (100 , dtype = "i4" )
169
+ x_flipped = dpt .flip (x_orig , axis = 0 )
170
+ s = dpt .sort (x_flipped )
171
+
172
+ assert dpt .all (s == x_orig )
173
+
174
+
175
+ def test_argsort_strided ():
176
+ get_queue_or_skip ()
177
+
178
+ x_orig = dpt .arange (100 , dtype = "i4" )
179
+ x_flipped = dpt .flip (x_orig , axis = 0 )
180
+ idx = dpt .argsort (x_flipped )
181
+
182
+ assert dpt .all (x_flipped [idx ] == x_orig )
0 commit comments