Skip to content

Commit 36d15bb

Browse files
committed
Fix __all__ not getting updated with reset_array_api_strict_flags()
1 parent 718f15b commit 36d15bb

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

array_api_strict/_flags.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,9 @@ def reset_array_api_strict_flags():
262262
BOOLEAN_INDEXING = True
263263
DATA_DEPENDENT_SHAPES = True
264264
ENABLED_EXTENSIONS = default_extensions
265-
265+
array_api_strict.__all__[:] = sorted(set(ENABLED_EXTENSIONS) |
266+
set(array_api_strict.__all__) -
267+
set(default_extensions))
266268

267269
class ArrayAPIStrictFlags:
268270
"""

array_api_strict/tests/test_flags.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,15 @@ def test_disabled_extensions():
371371
assert 'linalg' not in ns
372372
assert 'fft' not in ns
373373

374+
reset_array_api_strict_flags()
375+
assert 'linalg' in xp.__all__
376+
assert 'fft' in xp.__all__
377+
xp.linalg # No error
378+
xp.fft # No error
379+
ns = {}
380+
exec('from array_api_strict import *', ns)
381+
assert 'linalg' in ns
382+
assert 'fft' in ns
374383

375384
def test_environment_variables():
376385
# Test that the environment variables work as expected

0 commit comments

Comments
 (0)