@@ -87,8 +87,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
87
87
graph.
88
88
jax_jit : bool, optional
89
89
Set to True to replace `func` with ``jax.jit(func)`` after calling the
90
- :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False if
91
- `func` is only compatible with eager (non-jitted) JAX. Default: True.
90
+ :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False
91
+ if `func` is only compatible with eager (non-jitted) JAX. Default: True.
92
92
static_argnums : int | Sequence[int], optional
93
93
Passed to jax.jit. Positional arguments to treat as static (compile-time
94
94
constant). Default: infer from `static_argnames` using
@@ -113,7 +113,7 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
113
113
def test_myfunc(xp):
114
114
a = xp.asarray([1, 2])
115
115
# When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)`
116
- # When xp=dask.array, crash on compute() or persist()
116
+ # When xp=dask.array, crash on compute() or persist()
117
117
b = myfunc(a)
118
118
119
119
Notes
@@ -150,8 +150,8 @@ def patch_lazy_xp_functions(
150
150
:func:`lazy_xp_function` in the globals of the module that defines the current test
151
151
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
152
152
153
- If ``xp==dask.array``, wrap the functions with a decorator that disables ``compute()``
154
- and ``persist()``.
153
+ If ``xp==dask.array``, wrap the functions with a decorator that disables
154
+ ``compute()`` and ``persist()``.
155
155
156
156
This function should be typically called by your library's `xp` fixture that runs
157
157
tests on multiple backends::
0 commit comments