Skip to content

Commit 6cb00c3

Browse files
committed
TYP: adopt based{mypy, pyright}
1 parent c1d0e20 commit 6cb00c3

File tree

6 files changed

+257
-294
lines changed

6 files changed

+257
-294
lines changed

codecov.yml

+2
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
comment: false
2+
ignore:
3+
- "src/array_api_extra/_typing"

pixi.lock

+221-282
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+22-2
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,23 @@ array-api-extra = { path = ".", editable = true }
7070

7171
[tool.pixi.feature.lint.dependencies]
7272
pre-commit = "*"
73-
mypy = "*"
7473
pylint = "*"
7574
# import dependencies for mypy:
7675
array-api-strict = "*"
7776
numpy = "*"
7877
pytest = "*"
7978

79+
[tool.pixi.feature.lint.pypi-dependencies]
80+
basedmypy = "*"
81+
basedpyright = "*"
82+
8083
[tool.pixi.feature.lint.tasks]
8184
pre-commit-install = { cmd = "pre-commit install" }
8285
pre-commit = { cmd = "pre-commit run -v --all-files --show-diff-on-failure" }
8386
mypy = { cmd = "mypy", cwd = "." }
8487
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
85-
lint = { depends-on = ["pre-commit", "pylint", "mypy"] }
88+
pyright = { cmd = "basedpyright", cwd = "." }
89+
lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }
8690

8791
[tool.pixi.feature.tests.dependencies]
8892
pytest = ">=6"
@@ -165,13 +169,29 @@ enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
165169
warn_unreachable = true
166170
disallow_untyped_defs = false
167171
disallow_incomplete_defs = false
172+
# array-api#589
173+
disallow_any_expr = false
168174

169175
[[tool.mypy.overrides]]
170176
module = "array_api_extra.*"
171177
disallow_untyped_defs = true
172178
disallow_incomplete_defs = true
173179

174180

181+
# pyright
182+
183+
[tool.pyright]
184+
include = ["src", "tests"]
185+
pythonVersion = "3.10"
186+
pythonPlatform = "All"
187+
typeCheckingMode = "strict"
188+
189+
# array-api#589
190+
reportAny = false
191+
reportExplicitAny = false
192+
reportUnknownMemberType = false
193+
194+
175195
# Ruff
176196

177197
[tool.ruff]

src/array_api_extra/_funcs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3+
import typing
34
import warnings
4-
from typing import TYPE_CHECKING
55

6-
if TYPE_CHECKING:
6+
if typing.TYPE_CHECKING:
77
from ._typing import Array, ModuleType
88

99
__all__ = ["atleast_nd", "cov", "create_diagonal", "expand_dims", "kron", "sinc"]

src/array_api_extra/_typing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from types import ModuleType
44
from typing import Any
55

6-
Array = Any # To be changed to a Protocol later (see array-api#589)
6+
# To be changed to a Protocol later (see array-api#589)
7+
Array = Any # type: ignore[no-any-explicit]
78

89
__all__ = ["Array", "ModuleType"]

tests/test_funcs.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from array_api_extra import atleast_nd, cov, create_diagonal, expand_dims, kron, sinc
1414

1515
if TYPE_CHECKING:
16-
Array = Any # To be changed to a Protocol later (see array-api#589)
16+
# To be changed to a Protocol later (see array-api#589)
17+
Array = Any # type: ignore[no-any-explicit]
1718

1819

1920
class TestAtLeastND:
@@ -131,7 +132,7 @@ def test_1d(self):
131132

132133
@pytest.mark.parametrize("n", range(1, 10))
133134
@pytest.mark.parametrize("offset", range(1, 10))
134-
def test_create_diagonal(self, n, offset):
135+
def test_create_diagonal(self, n: int, offset: int):
135136
# from scipy._lib tests
136137
rng = np.random.default_rng(2347823)
137138
one = xp.asarray(1.0)
@@ -180,9 +181,9 @@ def test_basic(self):
180181
assert_array_equal(kron(a, b, xp=xp), k)
181182

182183
def test_kron_smoke(self):
183-
a = xp.ones([3, 3])
184-
b = xp.ones([3, 3])
185-
k = xp.ones([9, 9])
184+
a = xp.ones((3, 3))
185+
b = xp.ones((3, 3))
186+
k = xp.ones((9, 9))
186187

187188
assert_array_equal(kron(a, b, xp=xp), k)
188189

@@ -197,7 +198,7 @@ def test_kron_smoke(self):
197198
((2, 0, 0, 2), (2, 0, 2)),
198199
],
199200
)
200-
def test_kron_shape(self, shape_a, shape_b):
201+
def test_kron_shape(self, shape_a: tuple[int], shape_b: tuple[int]):
201202
a = xp.ones(shape_a)
202203
b = xp.ones(shape_b)
203204
normalised_shape_a = xp.asarray(
@@ -271,7 +272,7 @@ def test_simple(self):
271272
assert_allclose(w, xp.flip(w, axis=0))
272273

273274
@pytest.mark.parametrize("x", [0, 1 + 3j])
274-
def test_dtype(self, x):
275+
def test_dtype(self, x: int | complex):
275276
with pytest.raises(ValueError, match="real floating data type"):
276277
sinc(xp.asarray(x), xp=xp)
277278

0 commit comments

Comments
 (0)