|
14 | 14 | # See the License for the specific language governing permissions and
|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
| 17 | +import itertools |
| 18 | + |
| 19 | +import numpy as np |
17 | 20 | import pytest
|
18 | 21 |
|
19 | 22 | import dpctl.tensor as dpt
|
@@ -211,3 +214,69 @@ def test_argsort_0d_array():
|
211 | 214 |
|
212 | 215 | x = dpt.asarray(1, dtype="i4")
|
213 | 216 | assert dpt.argsort(x) == 0
|
| 217 | + |
| 218 | + |
| 219 | +@pytest.mark.parametrize( |
| 220 | + "dtype", |
| 221 | + [ |
| 222 | + "f2", |
| 223 | + "f4", |
| 224 | + "f8", |
| 225 | + ], |
| 226 | +) |
| 227 | +def test_sort_real_fp_nan(dtype): |
| 228 | + q = get_queue_or_skip() |
| 229 | + skip_if_dtype_not_supported(dtype, q) |
| 230 | + |
| 231 | + x = dpt.asarray( |
| 232 | + [-0.0, 0.1, dpt.nan, 0.0, -0.1, dpt.nan, 0.2, -0.3], dtype=dtype |
| 233 | + ) |
| 234 | + s = dpt.sort(x) |
| 235 | + |
| 236 | + expected = dpt.asarray( |
| 237 | + [-0.3, -0.1, -0.0, 0.0, 0.1, 0.2, dpt.nan, dpt.nan], dtype=dtype |
| 238 | + ) |
| 239 | + |
| 240 | + assert dpt.allclose(s, expected, equal_nan=True) |
| 241 | + |
| 242 | + s = dpt.sort(x, descending=True) |
| 243 | + |
| 244 | + expected = dpt.asarray( |
| 245 | + [dpt.nan, dpt.nan, 0.2, 0.1, -0.0, 0.0, -0.1, -0.3], dtype=dtype |
| 246 | + ) |
| 247 | + |
| 248 | + assert dpt.allclose(s, expected, equal_nan=True) |
| 249 | + |
| 250 | + |
| 251 | +@pytest.mark.parametrize( |
| 252 | + "dtype", |
| 253 | + [ |
| 254 | + "c8", |
| 255 | + "c16", |
| 256 | + ], |
| 257 | +) |
| 258 | +def test_sort_complex_fp_nan(dtype): |
| 259 | + q = get_queue_or_skip() |
| 260 | + skip_if_dtype_not_supported(dtype, q) |
| 261 | + |
| 262 | + rvs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan] |
| 263 | + ivs = [-0.0, 0.1, 0.0, 0.2, -0.3, dpt.nan] |
| 264 | + |
| 265 | + cv = [] |
| 266 | + for rv in rvs: |
| 267 | + for iv in ivs: |
| 268 | + cv.append(complex(rv, iv)) |
| 269 | + |
| 270 | + inp = dpt.asarray(cv, dtype=dtype) |
| 271 | + s = dpt.sort(inp) |
| 272 | + |
| 273 | + expected = np.sort(dpt.asnumpy(inp)) |
| 274 | + |
| 275 | + assert np.allclose(dpt.asnumpy(s), expected, equal_nan=True) |
| 276 | + |
| 277 | + for i, j in itertools.permutations(range(inp.shape[0]), 2): |
| 278 | + r1 = dpt.asnumpy(dpt.sort(inp[dpt.asarray([i, j])])) |
| 279 | + r2 = np.sort(dpt.asnumpy(inp[dpt.asarray([i, j])])) |
| 280 | + assert np.array_equal( |
| 281 | + r1.view(np.int64), r2.view(np.int64) |
| 282 | + ), f"Failed for {i} and {j}" |
0 commit comments