diff --git a/README.md b/README.md index 928a6771..17131560 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ You need to specify the array library to test. It can be specified via the $ export ARRAY_API_TESTS_MODULE=array_api_strict ``` +To specify a runtime-defined module, define `xp` using the `exec('...')` syntax: + +```bash +$ export ARRAY_API_TESTS_MODULE=exec('import quantity_array, numpy; xp = quantity_array.quantity_namespace(numpy)') +``` + Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`. ### Specifying the API version diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 4e0c340f..f3805b56 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -13,19 +13,27 @@ # You can comment the following out and instead import the specific array module # you want to test, e.g. `import array_api_strict as xp`. if "ARRAY_API_TESTS_MODULE" in os.environ: - xp_name = os.environ["ARRAY_API_TESTS_MODULE"] - _module, _sub = xp_name, None - if "." in xp_name: - _module, _sub = xp_name.split(".", 1) - xp = import_module(_module) - if _sub: - try: - xp = getattr(xp, _sub) - except AttributeError: - # _sub may be a submodule that needs to be imported. WE can't - # do this in every case because some array modules are not - # submodules that can be imported (like mxnet.nd). - xp = import_module(xp_name) + env_var = os.environ["ARRAY_API_TESTS_MODULE"] + if env_var.startswith("exec('") and env_var.endswith("')"): + script = env_var[6:][:-2] + namespace = {} + exec(script, namespace) + xp = namespace["xp"] + xp_name = xp.__name__ + else: + xp_name = env_var + _module, _sub = xp_name, None + if "." in xp_name: + _module, _sub = xp_name.split(".", 1) + xp = import_module(_module) + if _sub: + try: + xp = getattr(xp, _sub) + except AttributeError: + # _sub may be a submodule that needs to be imported. We can't + # do this in every case because some array modules are not + # submodules that can be imported (like mxnet.nd). + xp = import_module(xp_name) else: raise RuntimeError( "No array module specified - either edit __init__.py or set the "