Skip to content

Commit 478df33

Browse files
authored
Merge pull request #63 from Quansight-Labs/argminmax_bool
MAINT: work around arg{min,max} not implemented for booleans
2 parents 8bdf8b6 + 1ed3d24 commit 478df33

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

torch_np/_detail/_reductions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,22 @@ def count_nonzero(a, axis=None):
3636

3737
def argmax(tensor, axis=None):
3838
axis = _util.allow_only_single_axis(axis)
39+
40+
if tensor.dtype == torch.bool:
41+
# RuntimeError: "argmax_cpu" not implemented for 'Bool'
42+
tensor = tensor.to(torch.uint8)
43+
3944
tensor = torch.argmax(tensor, axis)
4045
return tensor
4146

4247

4348
def argmin(tensor, axis=None):
4449
axis = _util.allow_only_single_axis(axis)
50+
51+
if tensor.dtype == torch.bool:
52+
# RuntimeError: "argmin_cpu" not implemented for 'Bool'
53+
tensor = tensor.to(torch.uint8)
54+
4555
tensor = torch.argmin(tensor, axis)
4656
return tensor
4757

torch_np/tests/test_ndarray_methods.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,8 @@ def test_combinations(self, data):
407407
# with suppress_warnings() as sup:
408408
# sup.filter(RuntimeWarning,
409409
# "invalid value encountered in reduce")
410-
if np.asarray(arr).dtype.kind in "cb":
411-
pytest.xfail(
412-
reason="'max_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
413-
)
410+
if np.asarray(arr).dtype.kind in "c":
411+
pytest.xfail(reason="'max_values_cpu' not implemented for 'ComplexDouble'")
414412

415413
val = np.max(arr)
416414

@@ -511,10 +509,8 @@ class TestArgmin:
511509
def test_combinations(self, data):
512510
arr, pos = data
513511

514-
if np.asarray(arr).dtype.kind in "cb":
515-
pytest.xfail(
516-
reason="'min_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
517-
)
512+
if np.asarray(arr).dtype.kind in "c":
513+
pytest.xfail(reason="'min_values_cpu' not implemented for 'ComplexDouble'")
518514

519515
# with suppress_warnings() as sup:
520516
# sup.filter(RuntimeWarning, "invalid value encountered in reduce")

0 commit comments

Comments
 (0)