diff --git a/torch_np/_detail/_reductions.py b/torch_np/_detail/_reductions.py index d7688787..ac093699 100644 --- a/torch_np/_detail/_reductions.py +++ b/torch_np/_detail/_reductions.py @@ -36,12 +36,22 @@ def count_nonzero(a, axis=None): def argmax(tensor, axis=None): axis = _util.allow_only_single_axis(axis) + + if tensor.dtype == torch.bool: + # RuntimeError: "argmax_cpu" not implemented for 'Bool' + tensor = tensor.to(torch.uint8) + tensor = torch.argmax(tensor, axis) return tensor def argmin(tensor, axis=None): axis = _util.allow_only_single_axis(axis) + + if tensor.dtype == torch.bool: + # RuntimeError: "argmin_cpu" not implemented for 'Bool' + tensor = tensor.to(torch.uint8) + tensor = torch.argmin(tensor, axis) return tensor diff --git a/torch_np/tests/test_ndarray_methods.py b/torch_np/tests/test_ndarray_methods.py index 0e1fced8..db271aa0 100644 --- a/torch_np/tests/test_ndarray_methods.py +++ b/torch_np/tests/test_ndarray_methods.py @@ -404,10 +404,8 @@ def test_combinations(self, data): # with suppress_warnings() as sup: # sup.filter(RuntimeWarning, # "invalid value encountered in reduce") - if np.asarray(arr).dtype.kind in "cb": - pytest.xfail( - reason="'max_values_cpu' not implemented for 'ComplexDouble', 'Bool'" - ) + if np.asarray(arr).dtype.kind in "c": + pytest.xfail(reason="'max_values_cpu' not implemented for 'ComplexDouble'") val = np.max(arr) @@ -508,10 +506,8 @@ class TestArgmin: def test_combinations(self, data): arr, pos = data - if np.asarray(arr).dtype.kind in "cb": - pytest.xfail( - reason="'min_values_cpu' not implemented for 'ComplexDouble', 'Bool'" - ) + if np.asarray(arr).dtype.kind in "c": + pytest.xfail(reason="'min_values_cpu' not implemented for 'ComplexDouble'") # with suppress_warnings() as sup: # sup.filter(RuntimeWarning, "invalid value encountered in reduce")