-
Notifications
You must be signed in to change notification settings - Fork 4
MAINT: work around arg{min,max} not implemented for booleans #63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
torch_np/_detail/_reductions.py
Outdated
|
||
if tensor.dtype == torch.bool: | ||
# RuntimeError: "argmax_cpu" not implemented for 'Bool' | ||
tensor = tensor.view(torch.int8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the equivalence guaranteed in pytorch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would cast it to uint8
to be 100% sure, but yeah, I believe it is guaranteed by some work @peterbell10 did some time ago.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A cast is guaranteed to be 0 or 1, but a view could contain any byte value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh and of course the work I did doesn't apply to the torch.compile
stack.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, if viewing with uint8
we would get something that's non zero, so argmax would then work, but sure, let's do to
, as the compiler should be able to optimise out the copy really
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC argmax always returns the lowest index that's equal to the maximum, so the distinction does matter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving to unblock, but I believe it'd be better to cast it to uint8
to be 100% sure.
By cast you mean tensor.to or tensor.view? |
view is fine. |
711d3cc
to
1ed3d24
Compare
Based on #63 (comment) and #63 (comment) let's copy to uint8 to be on the safe side. |
No description provided.