File tree 2 files changed +14
-8
lines changed
2 files changed +14
-8
lines changed Original file line number Diff line number Diff line change @@ -36,12 +36,22 @@ def count_nonzero(a, axis=None):
36
36
37
37
def argmax (tensor , axis = None ):
38
38
axis = _util .allow_only_single_axis (axis )
39
+
40
+ if tensor .dtype == torch .bool :
41
+ # RuntimeError: "argmax_cpu" not implemented for 'Bool'
42
+ tensor = tensor .to (torch .uint8 )
43
+
39
44
tensor = torch .argmax (tensor , axis )
40
45
return tensor
41
46
42
47
43
48
def argmin (tensor , axis = None ):
44
49
axis = _util .allow_only_single_axis (axis )
50
+
51
+ if tensor .dtype == torch .bool :
52
+ # RuntimeError: "argmin_cpu" not implemented for 'Bool'
53
+ tensor = tensor .to (torch .uint8 )
54
+
45
55
tensor = torch .argmin (tensor , axis )
46
56
return tensor
47
57
Original file line number Diff line number Diff line change @@ -407,10 +407,8 @@ def test_combinations(self, data):
407
407
# with suppress_warnings() as sup:
408
408
# sup.filter(RuntimeWarning,
409
409
# "invalid value encountered in reduce")
410
- if np .asarray (arr ).dtype .kind in "cb" :
411
- pytest .xfail (
412
- reason = "'max_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
413
- )
410
+ if np .asarray (arr ).dtype .kind in "c" :
411
+ pytest .xfail (reason = "'max_values_cpu' not implemented for 'ComplexDouble'" )
414
412
415
413
val = np .max (arr )
416
414
@@ -511,10 +509,8 @@ class TestArgmin:
511
509
def test_combinations (self , data ):
512
510
arr , pos = data
513
511
514
- if np .asarray (arr ).dtype .kind in "cb" :
515
- pytest .xfail (
516
- reason = "'min_values_cpu' not implemented for 'ComplexDouble', 'Bool'"
517
- )
512
+ if np .asarray (arr ).dtype .kind in "c" :
513
+ pytest .xfail (reason = "'min_values_cpu' not implemented for 'ComplexDouble'" )
518
514
519
515
# with suppress_warnings() as sup:
520
516
# sup.filter(RuntimeWarning, "invalid value encountered in reduce")
You can’t perform that action at this time.
0 commit comments