Skip to content

Commit 5469832

Browse files
Adding tests for sorting of FP arrays with NaNs
1 parent a3d0d08 commit 5469832

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import itertools
18+
19+
import numpy as np
1720
import pytest
1821

1922
import dpctl.tensor as dpt
@@ -211,3 +214,69 @@ def test_argsort_0d_array():
211214

212215
x = dpt.asarray(1, dtype="i4")
213216
assert dpt.argsort(x) == 0
217+
218+
219+
@pytest.mark.parametrize(
220+
"dtype",
221+
[
222+
"f2",
223+
"f4",
224+
"f8",
225+
],
226+
)
227+
def test_sort_real_fp_nan(dtype):
228+
q = get_queue_or_skip()
229+
skip_if_dtype_not_supported(dtype, q)
230+
231+
x = dpt.asarray(
232+
[-0.0, 0.1, dpt.nan, 0.0, -0.1, dpt.nan, 0.2, -0.3], dtype=dtype
233+
)
234+
s = dpt.sort(x)
235+
236+
expected = dpt.asarray(
237+
[-0.3, -0.1, -0.0, 0.0, 0.1, 0.2, dpt.nan, dpt.nan], dtype=dtype
238+
)
239+
240+
assert dpt.allclose(s, expected, equal_nan=True)
241+
242+
s = dpt.sort(x, descending=True)
243+
244+
expected = dpt.asarray(
245+
[dpt.nan, dpt.nan, 0.2, 0.1, -0.0, 0.0, -0.1, -0.3], dtype=dtype
246+
)
247+
248+
assert dpt.allclose(s, expected, equal_nan=True)
249+
250+
251+
@pytest.mark.parametrize(
252+
"dtype",
253+
[
254+
"c8",
255+
"c16",
256+
],
257+
)
258+
def test_sort_complex_fp_nan(dtype):
259+
q = get_queue_or_skip()
260+
skip_if_dtype_not_supported(dtype, q)
261+
262+
rvs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan]
263+
ivs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan]
264+
265+
cv = []
266+
for rv in rvs:
267+
for iv in ivs:
268+
cv.append(complex(rv, iv))
269+
270+
inp = dpt.asarray(cv, dtype=dtype)
271+
s = dpt.sort(inp)
272+
273+
expected = np.sort(dpt.asnumpy(inp))
274+
275+
assert np.allclose(dpt.asnumpy(s), expected, equal_nan=True)
276+
277+
for i, j in itertools.permutations(range(inp.shape[0]), 2):
278+
r1 = dpt.asnumpy(dpt.sort(inp[dpt.asarray([i, j])]))
279+
r2 = np.sort(dpt.asnumpy(inp[dpt.asarray([i, j])]))
280+
assert np.array_equal(
281+
r1.view(np.int64), r2.view(np.int64)
282+
), f"Failed for {i} and {j}"

0 commit comments

Comments
 (0)