Skip to content

ENH: unary functions overhaul; better input validation #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 10, 2025

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Apr 23, 2025

xref #145

  • Rewrite all unary functions with a generator
  • Disallow numpy generics in binary functions, clip, and where
  • Improve error message when the first argument of where is not an Array
  • Test for device mismatches in the inputs of binary functions, clip, and where
  • Test input-output device propagation in where

@@ -168,9 +231,6 @@ def _array_vals():
for d in _floating_dtypes:
yield asarray(1.0, dtype=d)

# Use the latest version of the standard so all functions are included
set_array_api_strict_flags(api_version="2024.12")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant with auto-applied fixture

Copy link
Member

@ev-br ev-br left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Let keep this PR open for a while though, in case somebody has opinions on generating unary functions from a decorator. I personally think this is a good change, but there were concerns in #100

res = xp.where(cond, 1, x2)
assert res.device == device
res = xp.where(cond, x1, 2)
assert res.device == device
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the following tests are great. I imaging we'll want to parrot them in array-api-tests at some point.

@crusaderky
Copy link
Contributor Author

Note: this does not close #102 , as Python sneakily callls numpy.ndarray.__radd__. When array_api_strict.Array.__add__ fails. The opposite (LHS is numpy, RHS is array-api-strict) is also impossible to fix without disallowing __array__ and __buffer__.

@ev-br
Copy link
Member

ev-br commented Apr 24, 2025

Note: this does not close #102 , as Python sneakily callls numpy.ndarray.radd. When array_api_strict.Array.add fails. The opposite (LHS is numpy, RHS is array-api-strict) is also impossible to fix without disallowing array and buffer.

Exactly.
I suggest we ignore this for the time being.

@ev-br ev-br added this to the 2.4 milestone May 3, 2025
@crusaderky
Copy link
Contributor Author

@ev-br good to merge now?

Copy link
Member

@lucascolley lucascolley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given the merge of gh-100, we may as well be consistent, so no worries from my side. Thanks both!

@lucascolley lucascolley merged commit e25b928 into data-apis:main May 10, 2025
19 checks passed
@crusaderky crusaderky deleted the unary_generics branch May 10, 2025 22:31
@lucascolley
Copy link
Member

is the CI failure on main related to this PR? Or was that pre-existing?

@crusaderky
Copy link
Contributor Author

is the CI failure on main related to this PR? Or was that pre-existing?

I'm on holiday and I can't bisect anything, but it looks to me like a scalar arg is being passed on 2023.12. If the issue is in array-api-strict or array-api-tests I can't debug from here, but I don't think it was caused by this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants