Skip to content

Commit 9345e5f

Browse files
committed
test
1 parent 20740af commit 9345e5f

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tests/test_all.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,21 @@ def all_names(mod):
245245
return list(objs)
246246

247247

248+
def get_mod(library, module, *, compat):
249+
if compat:
250+
library = f"array_api_compat.{library}"
251+
xp = pytest.importorskip(library)
252+
return getattr(xp, module) if module else xp
253+
254+
248255
@pytest.mark.parametrize("func", [all_names, dir])
249256
@pytest.mark.parametrize("module", list(NAMES))
250257
@pytest.mark.parametrize("library", wrapped_libraries)
251258
def test_array_api_names(library, module, func):
252259
"""Test that __all__ and dir() aren't missing any exports
253260
dictated by the Standard.
254261
"""
255-
xp = pytest.importorskip(f"array_api_compat.{library}")
256-
mod = getattr(xp, module) if module else xp
262+
mod = get_mod(library, module, compat=True)
257263
missing = set(NAMES[module]) - set(func(mod))
258264
xfail = set(XFAILS.get((library, module), []))
259265
xpass = xfail - missing
@@ -269,10 +275,8 @@ def test_compat_doesnt_hide_names(library, module, func):
269275
"""The base namespace can have more names than the ones explicitly exported
270276
by array-api-compat. Test that we're not suppressing them.
271277
"""
272-
bare_xp = pytest.importorskip(library)
273-
compat_xp = pytest.importorskip(f"array_api_compat.{library}")
274-
bare_mod = getattr(bare_xp, module) if module else bare_xp
275-
compat_mod = getattr(compat_xp, module) if module else compat_xp
278+
bare_mod = get_mod(library, module, compat=False)
279+
compat_mod = get_mod(library, module, compat=True)
276280

277281
missing = set(func(bare_mod)) - set(func(compat_mod))
278282
missing = {name for name in missing if not name.startswith("_")}
@@ -286,10 +290,8 @@ def test_compat_doesnt_add_names(library, module, func):
286290
"""Test that array-api-compat isn't adding names to the namespace
287291
besides those defined by the Array API Standard.
288292
"""
289-
bare_xp = pytest.importorskip(library)
290-
compat_xp = pytest.importorskip(f"array_api_compat.{library}")
291-
bare_mod = getattr(bare_xp, module) if module else bare_xp
292-
compat_mod = getattr(compat_xp, module) if module else compat_xp
293+
bare_mod = get_mod(library, module, compat=False)
294+
compat_mod = get_mod(library, module, compat=True)
293295

294296
aapi_names = set(NAMES[module])
295297
spurious = set(func(compat_mod)) - set(func(bare_mod)) - aapi_names - {"__all__"}

0 commit comments

Comments
 (0)