|
1 | 1 | import subprocess
|
2 | 2 | import sys
|
| 3 | +import warnings |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 | import pytest
|
@@ -57,13 +58,24 @@ def test_array_namespace_errors():
|
57 | 58 | pytest.raises(TypeError, lambda: array_namespace((x, x)))
|
58 | 59 | pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
|
59 | 60 |
|
60 |
| - |
61 | 61 | def test_array_namespace_errors_torch():
|
62 | 62 | y = torch.asarray([1, 2])
|
63 | 63 | x = np.asarray([1, 2])
|
64 | 64 | pytest.raises(TypeError, lambda: array_namespace(x, y))
|
65 |
| - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12")) |
66 | 65 |
|
| 66 | +def test_api_version(): |
| 67 | + x = np.asarray([1, 2]) |
| 68 | + np_ = import_("numpy", wrapper=True) |
| 69 | + assert array_namespace(x, api_version="2022.12") == np_ |
| 70 | + assert array_namespace(x, api_version=None) == np_ |
| 71 | + assert array_namespace(x) == np_ |
| 72 | + # Should issue a warning |
| 73 | + with warnings.catch_warnings(record=True) as w: |
| 74 | + assert array_namespace(x, api_version="2021.12") == np_ |
| 75 | + assert len(w) == 1 |
| 76 | + assert "2021.12" in str(w[0].message) |
| 77 | + |
| 78 | + pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) |
67 | 79 |
|
68 | 80 | def test_get_namespace():
|
69 | 81 | # Backwards compatible wrapper
|
|
0 commit comments