@@ -245,15 +245,21 @@ def all_names(mod):
245
245
return list (objs )
246
246
247
247
248
+ def get_mod (library , module , * , compat ):
249
+ if compat :
250
+ library = f"array_api_compat.{ library } "
251
+ xp = pytest .importorskip (library )
252
+ return getattr (xp , module ) if module else xp
253
+
254
+
248
255
@pytest .mark .parametrize ("func" , [all_names , dir ])
249
256
@pytest .mark .parametrize ("module" , list (NAMES ))
250
257
@pytest .mark .parametrize ("library" , wrapped_libraries )
251
258
def test_array_api_names (library , module , func ):
252
259
"""Test that __all__ and dir() aren't missing any exports
253
260
dictated by the Standard.
254
261
"""
255
- xp = pytest .importorskip (f"array_api_compat.{ library } " )
256
- mod = getattr (xp , module ) if module else xp
262
+ mod = get_mod (library , module , compat = True )
257
263
missing = set (NAMES [module ]) - set (func (mod ))
258
264
xfail = set (XFAILS .get ((library , module ), []))
259
265
xpass = xfail - missing
@@ -269,10 +275,8 @@ def test_compat_doesnt_hide_names(library, module, func):
269
275
"""The base namespace can have more names than the ones explicitly exported
270
276
by array-api-compat. Test that we're not suppressing them.
271
277
"""
272
- bare_xp = pytest .importorskip (library )
273
- compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
274
- bare_mod = getattr (bare_xp , module ) if module else bare_xp
275
- compat_mod = getattr (compat_xp , module ) if module else compat_xp
278
+ bare_mod = get_mod (library , module , compat = False )
279
+ compat_mod = get_mod (library , module , compat = True )
276
280
277
281
missing = set (func (bare_mod )) - set (func (compat_mod ))
278
282
missing = {name for name in missing if not name .startswith ("_" )}
@@ -286,10 +290,8 @@ def test_compat_doesnt_add_names(library, module, func):
286
290
"""Test that array-api-compat isn't adding names to the namespace
287
291
besides those defined by the Array API Standard.
288
292
"""
289
- bare_xp = pytest .importorskip (library )
290
- compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
291
- bare_mod = getattr (bare_xp , module ) if module else bare_xp
292
- compat_mod = getattr (compat_xp , module ) if module else compat_xp
293
+ bare_mod = get_mod (library , module , compat = False )
294
+ compat_mod = get_mod (library , module , compat = True )
293
295
294
296
aapi_names = set (NAMES [module ])
295
297
spurious = set (func (compat_mod )) - set (func (bare_mod )) - aapi_names - {"__all__" }
0 commit comments