Skip to content

Commit 711d3cc

Browse files
committed
MAINT: work around arg{min,max} not implemented for booleans
1 parent 4c5eb7d commit 711d3cc

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

torch_np/_detail/_reductions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,22 @@ def count_nonzero(a, axis=None):
3636

3737
def argmax(tensor, axis=None):
3838
axis = _util.allow_only_single_axis(axis)
39+
40+
if tensor.dtype == torch.bool:
41+
# RuntimeError: "argmax_cpu" not implemented for 'Bool'
42+
tensor = tensor.view(torch.int8)
43+
3944
tensor = torch.argmax(tensor, axis)
4045
return tensor
4146

4247

4348
def argmin(tensor, axis=None):
4449
axis = _util.allow_only_single_axis(axis)
50+
51+
if tensor.dtype == torch.bool:
52+
# RuntimeError: "argmin_cpu" not implemented for 'Bool'
53+
tensor = tensor.view(torch.int8)
54+
4555
tensor = torch.argmin(tensor, axis)
4656
return tensor
4757

torch_np/tests/test_ndarray_methods.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,10 +404,8 @@ def test_combinations(self, data):
404404
# with suppress_warnings() as sup:
405405
# sup.filter(RuntimeWarning,
406406
# "invalid value encountered in reduce")
407-
if np.asarray(arr).dtype.kind in "cb":
408-
pytest.xfail(
409-
reason="'max_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
410-
)
407+
if np.asarray(arr).dtype.kind in "c":
408+
pytest.xfail(reason="'max_values_cpu' not implemented for 'ComplexDouble'")
411409

412410
val = np.max(arr)
413411

@@ -508,10 +506,8 @@ class TestArgmin:
508506
def test_combinations(self, data):
509507
arr, pos = data
510508

511-
if np.asarray(arr).dtype.kind in "cb":
512-
pytest.xfail(
513-
reason="'min_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
514-
)
509+
if np.asarray(arr).dtype.kind in "c":
510+
pytest.xfail(reason="'min_values_cpu' not implemented for 'ComplexDouble'")
515511

516512
# with suppress_warnings() as sup:
517513
# sup.filter(RuntimeWarning, "invalid value encountered in reduce")

0 commit comments

Comments
 (0)