ENH: lazy_xp_function support for wrapped return values #270

Open
crusaderky opened this issue Apr 21, 2025 · 0 comments · May be fixed by #284
scipy.stats.ttest_ind fails when wrapped with lazy_xp_function.
The reason is that the function returns a custom subclass of NamedTuple.
@jax.jit automatically repacks returned lists, tuples, and NamedTuples of arrays, but fails with custom classes.
This, however, is an artefact specific to lazy_xp_function; in real life, end users will unpack and consume the return value of ttest_ind within the scope of the jit.
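To make the repacking failure concrete, here is a toy model of pytree flattening in plain Python. It is NOT JAX's real implementation; it only assumes, as the traceback below suggests, that a namedtuple-like output (a tuple subclass with `_fields`) is rebuilt by calling its class with only the tuple fields. `Pair` and `BunchLike` are hypothetical stand-ins: `Pair` is an ordinary NamedTuple, and `BunchLike` imitates a TtestResult-like class whose constructor needs more arguments than its tuple fields.

```python
from __future__ import annotations

from typing import Any, Callable, NamedTuple


def toy_flatten(obj: Any) -> tuple[list[Any], Callable[[list[Any]], Any]]:
    """Toy pytree flattening (NOT JAX's real implementation).

    Lists and tuples (including namedtuple-like tuple subclasses) are
    containers; everything else is a leaf. Returns the leaves plus a
    function that rebuilds the same structure around new leaves.
    """
    if not isinstance(obj, (list, tuple)):
        return [obj], lambda new: new[0]  # a leaf, e.g. an array

    cls = type(obj)
    children = [toy_flatten(child) for child in obj]
    leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]

    def rebuild(new_leaves: list[Any]) -> Any:
        parts, i = [], 0
        for child_leaves, child_rebuild in children:
            n = len(child_leaves)
            parts.append(child_rebuild(new_leaves[i : i + n]))
            i += n
        if hasattr(cls, "_fields"):
            # namedtuple-like: rebuilt by passing ONLY the tuple fields
            return cls(*parts)
        return cls(parts)

    return leaves, rebuild


class Pair(NamedTuple):
    """A plain NamedTuple result: round-trips fine."""

    statistic: float
    pvalue: float


class BunchLike(tuple):
    """Hypothetical stand-in for a TtestResult-like class: a tuple of two
    fields whose constructor also requires an extra argument."""

    _fields = ("statistic", "pvalue")

    def __new__(cls, statistic: float, pvalue: float, df: float) -> BunchLike:
        self = super().__new__(cls, (statistic, pvalue))
        self.df = df  # extra attribute, not part of the tuple
        return self
```

Flattening `Pair(1.0, 2.0)` and rebuilding it with new leaves works, while rebuilding a `BunchLike` raises `TypeError` for the missing `df`, which mirrors the shape of the `TtestResult.__init__()` error reported below.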

In other words, this fails:

from scipy.stats import ttest_ind
xpx.lazy_xp_function(ttest_ind)

def test1(xp):
    x = xp.asarray([.1, .2])
    y = xp.asarray([.3])
    res = ttest_ind(x, y)
    # res = TtestResult(statistic=np.float64(-1.7320508075688765), pvalue=np.float64(0.3333333333333334), df=np.float64(1.0))

FAILED test1.py::test1[jax.numpy] - TypeError: TtestResult.__init__() missing 4 required positional arguments: 'df', 'alternative', 'standard_error', and 'estimate'

because it is equivalent to:

>>> import jax
>>> import jax.numpy as jnp
>>> from scipy.stats import ttest_ind
>>> jitted = jax.jit(ttest_ind)
>>> jitted(jnp.asarray([.1, .2]), jnp.asarray([.3]))
TypeError: TtestResult.__init__() missing 4 required positional arguments: 'df', 'alternative', 'standard_error', and 'estimate'

However, in real life users will not write the above; they will instead write something like:

>>> from scipy.stats import ttest_ind
>>> @jax.jit
... def f(x, y):
...     res = ttest_ind(x, y)
...     # Stand-in for some post-processing
...     return res.statistic, res.pvalue, res.df
...
>>> f(jnp.asarray([.1, .2]), jnp.asarray([.3]))
(Array(-1.7320508, dtype=float32),
 Array(0.3333333, dtype=float32),
 Array(1., dtype=float32))

Proposed design

Use pickle hooks to automatically unpack and repack complex return values.
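One way to read "pickle hooks" is the standard library's persistent-ID mechanism: a `pickle.Pickler` subclass overrides `persistent_id` to pull every array out of an arbitrary return value, and a `pickle.Unpickler` subclass overrides `persistent_load` to put (possibly new) arrays back. The sketch below assumes that interpretation; `flatten`, `unflatten` and `Leaf` are hypothetical names, and `Leaf` merely stands in for an array or JAX tracer. Because pickling goes through each object's own reduce protocol, any picklable result class can be repacked, not only lists, tuples and NamedTuples. How the pickled `rest` would be carried outside the jitted function is not shown here.

```python
from __future__ import annotations

import io
import pickle
from types import SimpleNamespace
from typing import Any


class Leaf:
    """Hypothetical stand-in for an array (e.g. a JAX tracer) that must
    cross the jit boundary as-is instead of being pickled."""

    def __init__(self, value: float) -> None:
        self.value = value


def flatten(obj: Any, leaf_type: type) -> tuple[list[Any], bytes]:
    """Split obj into (leaves, rest).

    Every instance of leaf_type found while pickling obj is collected into
    `leaves` and replaced by its index in the pickled byte string `rest`.
    """
    leaves: list[Any] = []

    class _Pickler(pickle.Pickler):
        def persistent_id(self, o: Any) -> int | None:
            if isinstance(o, leaf_type):
                leaves.append(o)
                return len(leaves) - 1  # placeholder: index into leaves
            return None  # anything else is pickled normally

    buf = io.BytesIO()
    _Pickler(buf).dump(obj)
    return leaves, buf.getvalue()


def unflatten(leaves: list[Any], rest: bytes) -> Any:
    """Inverse of flatten: rebuild the object around (possibly new) leaves."""

    class _Unpickler(pickle.Unpickler):
        def persistent_load(self, pid: Any) -> Any:
            return leaves[pid]  # swap each placeholder for its leaf

    return _Unpickler(io.BytesIO(rest)).load()


# Usage: a result object mixing array fields with non-array metadata.
res = SimpleNamespace(statistic=Leaf(-1.73), pvalue=Leaf(0.33), alternative="two-sided")
leaves, rest = flatten(res, Leaf)  # only `leaves` would need to cross the jit
new = unflatten([Leaf(1.0), Leaf(2.0)], rest)  # repack around new arrays
```

Only the flat list of leaves is a structure `jax.jit` already knows how to return; the rest of the object, including metadata such as `alternative`, travels as ordinary pickled bytes.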

@lucascolley lucascolley linked a pull request May 11, 2025 that will close this issue
@lucascolley lucascolley added this to the 0.8.0 milestone May 11, 2025