Skip to content

Commit

Permalink
Merge pull request #336 from lucascolley/runtime-xp
Browse files Browse the repository at this point in the history
feat: `ARRAY_API_TESTS_MODULE` for runtime-defined xp
  • Loading branch information
lucascolley authored Jan 24, 2025
1 parent f7a74a6 commit a882502
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ You need to specify the array library to test. It can be specified via the
$ export ARRAY_API_TESTS_MODULE=array_api_strict
```

To specify a runtime-defined module, define `xp` using the `exec('...')` syntax:

```bash
$ export ARRAY_API_TESTS_MODULE=exec('import quantity_array, numpy; xp = quantity_array.quantity_namespace(numpy)')
```

Alternately, import/define the `xp` variable in `array_api_tests/__init__.py`.

### Specifying the API version
Expand Down
34 changes: 21 additions & 13 deletions array_api_tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@
# You can comment the following out and instead import the specific array module
# you want to test, e.g. `import array_api_strict as xp`.
if "ARRAY_API_TESTS_MODULE" in os.environ:
xp_name = os.environ["ARRAY_API_TESTS_MODULE"]
_module, _sub = xp_name, None
if "." in xp_name:
_module, _sub = xp_name.split(".", 1)
xp = import_module(_module)
if _sub:
try:
xp = getattr(xp, _sub)
except AttributeError:
# _sub may be a submodule that needs to be imported. WE can't
# do this in every case because some array modules are not
# submodules that can be imported (like mxnet.nd).
xp = import_module(xp_name)
env_var = os.environ["ARRAY_API_TESTS_MODULE"]
if env_var.startswith("exec('") and env_var.endswith("')"):
script = env_var[6:][:-2]
namespace = {}
exec(script, namespace)
xp = namespace["xp"]
xp_name = xp.__name__
else:
xp_name = env_var
_module, _sub = xp_name, None
if "." in xp_name:
_module, _sub = xp_name.split(".", 1)
xp = import_module(_module)
if _sub:
try:
xp = getattr(xp, _sub)
except AttributeError:
# _sub may be a submodule that needs to be imported. We can't
# do this in every case because some array modules are not
# submodules that can be imported (like mxnet.nd).
xp = import_module(xp_name)
else:
raise RuntimeError(
"No array module specified - either edit __init__.py or set the "
Expand Down

0 comments on commit a882502

Please sign in to comment.