Skip to content

Commit 9dec816

Browse files
Expand tests for sorting to improve coverage
1 parent 8053991 commit 9dec816

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,57 @@ def test_argsort_1d(dtype):
126126

127127
s1_idx = dpt.argsort(inp, descending=True)
128128
assert dpt.all(inp[s1_idx[:-1]] >= inp[s1_idx[1:]])
129+
130+
131+
def test_sort_validation():
132+
with pytest.raises(TypeError):
133+
dpt.sort(dict())
134+
135+
136+
def test_argsort_validation():
137+
with pytest.raises(TypeError):
138+
dpt.argsort(dict())
139+
140+
141+
def test_sort_axis0():
142+
get_queue_or_skip()
143+
144+
n, m = 200, 30
145+
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
146+
x = dpt.reshape(xf, (n, m))
147+
s = dpt.sort(x, axis=0)
148+
149+
assert dpt.all(s[:-1, :] <= s[1:, :])
150+
151+
152+
def test_argsort_axis0():
153+
get_queue_or_skip()
154+
155+
n, m = 200, 30
156+
xf = dpt.arange(n * m, 0, step=-1, dtype="i4")
157+
x = dpt.reshape(xf, (n, m))
158+
idx = dpt.argsort(x, axis=0)
159+
160+
s = x[idx, dpt.arange(m)[dpt.newaxis, :]]
161+
162+
assert dpt.all(s[:-1, :] <= s[1:, :])
163+
164+
165+
def test_sort_strided():
166+
get_queue_or_skip()
167+
168+
x_orig = dpt.arange(100, dtype="i4")
169+
x_flipped = dpt.flip(x_orig, axis=0)
170+
s = dpt.sort(x_flipped)
171+
172+
assert dpt.all(s == x_orig)
173+
174+
175+
def test_argsort_strided():
176+
get_queue_or_skip()
177+
178+
x_orig = dpt.arange(100, dtype="i4")
179+
x_flipped = dpt.flip(x_orig, axis=0)
180+
idx = dpt.argsort(x_flipped)
181+
182+
assert dpt.all(x_flipped[idx] == x_orig)

0 commit comments

Comments
 (0)