Skip to content

Commit b88a9f1

Browse files
Tests for set functions
1 parent 9dec816 commit b88a9f1

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
21+
22+
23+
@pytest.mark.parametrize(
24+
"dtype",
25+
[
26+
"i1",
27+
"u1",
28+
"i2",
29+
"u2",
30+
"i4",
31+
"u4",
32+
"i8",
33+
"u8",
34+
"f2",
35+
"f4",
36+
"f8",
37+
"c8",
38+
"c16",
39+
],
40+
)
41+
def test_unique_values(dtype):
42+
q = get_queue_or_skip()
43+
skip_if_dtype_not_supported(dtype, q)
44+
45+
n, roll = 10000, 734
46+
inp = dpt.roll(
47+
dpt.concat((dpt.ones(n, dtype=dtype), dpt.zeros(n, dtype=dtype))),
48+
roll,
49+
)
50+
51+
uv = dpt.unique_values(inp)
52+
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
53+
54+
55+
@pytest.mark.parametrize(
56+
"dtype",
57+
[
58+
"i1",
59+
"u1",
60+
"i2",
61+
"u2",
62+
"i4",
63+
"u4",
64+
"i8",
65+
"u8",
66+
"f2",
67+
"f4",
68+
"f8",
69+
"c8",
70+
"c16",
71+
],
72+
)
73+
def test_unique_counts(dtype):
74+
q = get_queue_or_skip()
75+
skip_if_dtype_not_supported(dtype, q)
76+
77+
n, roll = 10000, 734
78+
inp = dpt.roll(
79+
dpt.concat((dpt.ones(n, dtype=dtype), dpt.zeros(n, dtype=dtype))),
80+
roll,
81+
)
82+
83+
uv, uv_counts = dpt.unique_counts(inp)
84+
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
85+
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))
86+
87+
88+
@pytest.mark.parametrize(
89+
"dtype",
90+
[
91+
"i1",
92+
"u1",
93+
"i2",
94+
"u2",
95+
"i4",
96+
"u4",
97+
"i8",
98+
"u8",
99+
"f2",
100+
"f4",
101+
"f8",
102+
"c8",
103+
"c16",
104+
],
105+
)
106+
def test_unique_inverse(dtype):
107+
q = get_queue_or_skip()
108+
skip_if_dtype_not_supported(dtype, q)
109+
110+
n, roll = 10000, 734
111+
inp = dpt.roll(
112+
dpt.concat((dpt.ones(n, dtype=dtype), dpt.zeros(n, dtype=dtype))),
113+
roll,
114+
)
115+
116+
uv, inv = dpt.unique_inverse(inp)
117+
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
118+
assert dpt.all(inp == uv[inv])
119+
120+
121+
@pytest.mark.parametrize(
122+
"dtype",
123+
[
124+
"i1",
125+
"u1",
126+
"i2",
127+
"u2",
128+
"i4",
129+
"u4",
130+
"i8",
131+
"u8",
132+
"f2",
133+
"f4",
134+
"f8",
135+
"c8",
136+
"c16",
137+
],
138+
)
139+
def test_unique_all(dtype):
140+
q = get_queue_or_skip()
141+
skip_if_dtype_not_supported(dtype, q)
142+
143+
n, roll = 10000, 734
144+
inp = dpt.roll(
145+
dpt.concat((dpt.ones(n, dtype=dtype), dpt.zeros(n, dtype=dtype))),
146+
roll,
147+
)
148+
149+
uv, ind, inv, uv_counts = dpt.unique_all(inp)
150+
assert dpt.all(uv == dpt.arange(2, dtype=dtype))
151+
assert dpt.all(uv == inp[ind])
152+
assert dpt.all(inp == uv[inv])
153+
assert dpt.all(uv_counts == dpt.full(2, n, dtype=uv_counts.dtype))

0 commit comments

Comments
 (0)